diff --git a/go.sum b/go.sum index 11eb8e7d1..cd4de014e 100644 --- a/go.sum +++ b/go.sum @@ -179,6 +179,7 @@ github.com/imdario/mergo v0.3.8 h1:CGgOkSJeqMRmt0D9XLWExdT4m4F1vd3FV3VPt+0VxkQ= github.com/imdario/mergo v0.3.8/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/ipfs/go-cid v0.0.3 h1:UIAh32wymBpStoe83YCzwVQQ5Oy/H0FdxvUS6DJDzms= github.com/ipfs/go-cid v0.0.3/go.mod h1:GHWU/WuQdMPmIosc4Yn1bcCT7dSeX4lBafM7iqUPQvM= +github.com/ipfs/go-ipfs-util v0.0.1 h1:Wz9bL2wB2YBJqggkA4dD7oSmqB4cAnpNbGrlHJulv50= github.com/ipfs/go-ipfs-util v0.0.1/go.mod h1:spsl5z8KUnrve+73pOhSVZND1SIxPW5RyBCNzQxlJBc= github.com/ipfs/testground v0.0.0-20200121194104-fd53f27ef027/go.mod h1:gawN4GBDQfpi0HWwh25ZXK/bSxZVoMEjdlpSmjJt2t8= github.com/jbenet/go-cienv v0.1.0/go.mod h1:TqNnHUmJgXau0nCzC7kXWeotg3J9W34CUv5Djy1+FlA= diff --git a/pkg/dockermanager/dockermanager.go b/pkg/dockermanager/dockermanager.go index 24437461a..0c7e6a24a 100644 --- a/pkg/dockermanager/dockermanager.go +++ b/pkg/dockermanager/dockermanager.go @@ -109,11 +109,9 @@ func (dm *Manager) Manage( managers := make(map[string]workerHandle) defer func() { - // cancel the remaining managers - for _, m := range managers { - m.cancel() - } // wait for the running managers to exit + // They'll get canceled when we close the main context (deferred + // below). for _, m := range managers { <-m.done } diff --git a/pkg/sidecar/docker_instance.go b/pkg/sidecar/docker_instance.go index 2a4c374af..b3f919c83 100644 --- a/pkg/sidecar/docker_instance.go +++ b/pkg/sidecar/docker_instance.go @@ -275,7 +275,7 @@ func (d *DockerInstanceManager) manageContainer(ctx context.Context, container * } } } - return NewInstance(runenv, info.Config.Hostname, network, newDockerLogs(logs)) + return NewInstance(ctx, runenv, info.Config.Hostname, network, newDockerLogs(logs)) } type dockerLink struct { diff --git a/pkg/sidecar/instance.go b/pkg/sidecar/instance.go index a58a7130a..38bc96fda 100644 --- a/pkg/sidecar/instance.go +++ b/pkg/sidecar/instance.go @@ -43,9 +43,9 @@ type Logs interface { } // NewInstance constructs a new test instance handle. -func NewInstance(runenv *runtime.RunEnv, hostname string, network Network, logs Logs) (*Instance, error) { +func NewInstance(ctx context.Context, runenv *runtime.RunEnv, hostname string, network Network, logs Logs) (*Instance, error) { // Get a redis reader/writer. 
- watcher, writer, err := sync.WatcherWriter(runenv) + watcher, writer, err := sync.WatcherWriter(ctx, runenv) if err != nil { return nil, fmt.Errorf("during sync.WatcherWriter: %w", err) } diff --git a/pkg/sidecar/k8s_instance.go b/pkg/sidecar/k8s_instance.go index 3d228b524..bb5534d59 100644 --- a/pkg/sidecar/k8s_instance.go +++ b/pkg/sidecar/k8s_instance.go @@ -85,7 +85,11 @@ func (d *K8sInstanceManager) Close() error { func (d *K8sInstanceManager) manageContainer(ctx context.Context, container *dockermanager.Container) (inst *Instance, err error) { // TODO: sidecar is racing to modify container network with CNI and pod getting ready // we should probably adjust this function to be called when a pod is in `1/1 Ready` state, and not just listen on the docker socket - time.Sleep(20 * time.Second) + select { + case <-time.After(20 * time.Second): + case <-ctx.Done(): + return nil, ctx.Err() + } // Get the state/config of the cluster info, err := container.Inspect(ctx) @@ -245,7 +249,7 @@ func (d *K8sInstanceManager) manageContainer(ctx context.Context, container *doc } } - return NewInstance(runenv, info.Config.Hostname, network, newDockerLogs(logs)) + return NewInstance(ctx, runenv, info.Config.Hostname, network, newDockerLogs(logs)) } type k8sLink struct { diff --git a/pkg/sidecar/sidecar.go b/pkg/sidecar/sidecar.go index 2a253ce64..f48b24286 100644 --- a/pkg/sidecar/sidecar.go +++ b/pkg/sidecar/sidecar.go @@ -77,7 +77,7 @@ func Run(runnerName string, resultPath string) error { // Wait for all the sidecars to enter the "network-initialized" state. const netInitState = "network-initialized" - if _, err = instance.Writer.SignalEntry(netInitState); err != nil { + if _, err = instance.Writer.SignalEntry(ctx, netInitState); err != nil { return fmt.Errorf("failed to signal network ready: %w", err) } instance.S().Infof("waiting for all networks to be ready") @@ -93,22 +93,16 @@ func Run(runnerName string, resultPath string) error { // Now let the test case tell us how to configure the network. 
subtree := sync.NetworkSubtree(instance.Hostname) networkChanges := make(chan *sync.NetworkConfig, 16) - closeSub, err := instance.Watcher.Subscribe(subtree, networkChanges) - if err != nil { + if err := instance.Watcher.Subscribe(ctx, subtree, networkChanges); err != nil { return fmt.Errorf("failed to subscribe to network changes: %s", err) } - defer func() { - if err := closeSub(); err != nil { - instance.S().Warnf("failed to close sub: %s", err) - } - }() for cfg := range networkChanges { instance.S().Infow("applying network change", "network", cfg) if err := instance.Network.ConfigureNetwork(ctx, cfg); err != nil { return fmt.Errorf("failed to update network %s: %w", cfg.Network, err) } if cfg.State != "" { - _, err := instance.Writer.SignalEntry(cfg.State) + _, err := instance.Writer.SignalEntry(ctx, cfg.State) if err != nil { return fmt.Errorf( "failed to signal network state change %s: %w", diff --git a/plans/bitswap-tuning/test/transfer.go b/plans/bitswap-tuning/test/transfer.go index 96b5bfd45..ce98203d8 100644 --- a/plans/bitswap-tuning/test/transfer.go +++ b/plans/bitswap-tuning/test/transfer.go @@ -39,7 +39,7 @@ func Transfer(runenv *runtime.RunEnv) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - watcher, writer := sync.MustWatcherWriter(runenv) + watcher, writer := sync.MustWatcherWriter(ctx, runenv) /// --- Tear down defer func() { @@ -64,7 +64,7 @@ func Transfer(runenv *runtime.RunEnv) error { defer node.Close() // Get sequence number of this host - seq, err := writer.Write(sync.PeerSubtree, host.InfoFromHost(node.Host)) + seq, err := writer.Write(ctx, sync.PeerSubtree, host.InfoFromHost(node.Host)) if err != nil { return err } @@ -93,37 +93,38 @@ func Transfer(runenv *runtime.RunEnv) error { } // Inform other nodes of the root CID - if _, err = writer.Write(RootCidSubtree, &rootCid); err != nil { + if _, err = writer.Write(ctx, RootCidSubtree, &rootCid); err != nil { return fmt.Errorf("Failed to get Redis Sync RootCidSubtree %w", err) } } else if isLeech { // Get the root CID from a seed rootCidCh := make(chan *cid.Cid, 1) - cancelRootCidSub, err := watcher.Subscribe(RootCidSubtree, rootCidCh) - if err != nil { + sctx, cancelRootCidSub := context.WithCancel(ctx) + if err := watcher.Subscribe(sctx, RootCidSubtree, rootCidCh); err != nil { return fmt.Errorf("Failed to subscribe to RootCidSubtree %w", err) } // Note: only need to get the root CID from one seed - it should be the // same on all seeds (seed data is generated from repeatable random // sequence) - select { - case rootCidPtr := <-rootCidCh: - cancelRootCidSub() - rootCid = *rootCidPtr - case <-time.After(timeout): - cancelRootCidSub() + rootCidPtr, ok := <-rootCidCh + cancelRootCidSub() + if !ok { return fmt.Errorf("no root cid in %d seconds", timeout/time.Second) } + rootCid = *rootCidPtr } // Get addresses of all peers peerCh := make(chan *peer.AddrInfo) - cancelSub, err := watcher.Subscribe(sync.PeerSubtree, peerCh) - addrInfos, err := utils.AddrInfosFromChan(peerCh, runenv.TestInstanceCount, timeout) + sctx, cancelSub := context.WithCancel(ctx) + if err := watcher.Subscribe(sctx, sync.PeerSubtree, peerCh); err != nil { + return err + } + addrInfos, err := utils.AddrInfosFromChan(peerCh, runenv.TestInstanceCount) if err != nil { cancelSub() - return err + return fmt.Errorf("no addrs in %d seconds", timeout/time.Second) } cancelSub() diff --git a/plans/bitswap-tuning/utils/net.go b/plans/bitswap-tuning/utils/net.go index a36be120a..1b1fab626 100644 --- 
a/plans/bitswap-tuning/utils/net.go +++ b/plans/bitswap-tuning/utils/net.go @@ -30,7 +30,7 @@ func SetupNetwork(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Wat latency := time.Duration(runenv.IntParam("latency_ms")) * time.Millisecond bandwidth := runenv.IntParam("bandwidth_mb") - writer.Write(sync.NetworkSubtree(hostname), &sync.NetworkConfig{ + _, err = writer.Write(ctx, sync.NetworkSubtree(hostname), &sync.NetworkConfig{ Network: "default", Enable: true, Default: sync.LinkShape{ @@ -39,6 +39,9 @@ func SetupNetwork(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Wat }, State: "network-configured", }) + if err != nil { + return err + } err = <-watcher.Barrier(ctx, "network-configured", int64(runenv.TestInstanceCount)) if err != nil { diff --git a/plans/bitswap-tuning/utils/peers.go b/plans/bitswap-tuning/utils/peers.go index fd940f40a..b7bbf1754 100644 --- a/plans/bitswap-tuning/utils/peers.go +++ b/plans/bitswap-tuning/utils/peers.go @@ -4,22 +4,19 @@ import ( "bytes" "context" "fmt" - "time" "github.com/libp2p/go-libp2p-core/peer" host "github.com/libp2p/go-libp2p-host" ) -func AddrInfosFromChan(peerCh chan *peer.AddrInfo, count int, timeout time.Duration) ([]peer.AddrInfo, error) { +func AddrInfosFromChan(peerCh chan *peer.AddrInfo, count int) ([]peer.AddrInfo, error) { var ais []peer.AddrInfo for i := 1; i <= count; i++ { - select { - case ai := <-peerCh: - ais = append(ais, *ai) - - case <-time.After(timeout): - return nil, fmt.Errorf("no new peers in %d seconds", timeout/time.Second) + ai, ok := <-peerCh + if !ok { + return ais, fmt.Errorf("subscription closed") } + ais = append(ais, *ai) } return ais, nil } diff --git a/plans/bitswap-tuning/utils/state.go b/plans/bitswap-tuning/utils/state.go index 0ac05f9af..769961d4a 100644 --- a/plans/bitswap-tuning/utils/state.go +++ b/plans/bitswap-tuning/utils/state.go @@ -12,15 +12,11 @@ func SignalAndWaitForAll(ctx context.Context, instanceCount int, stateName strin doneCh := watcher.Barrier(ctx, state, int64(instanceCount)) // Signal we've entered the state. - _, err := writer.SignalEntry(state) + _, err := writer.SignalEntry(ctx, state) if err != nil { return err } // Wait until all others have signalled. 
- if err = <-doneCh; err != nil { - return err - } - - return nil + return <-doneCh } diff --git a/plans/dht/test/common.go b/plans/dht/test/common.go index ef132cb2f..37d8b5561 100644 --- a/plans/dht/test/common.go +++ b/plans/dht/test/common.go @@ -126,7 +126,7 @@ func SetupNetwork(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Wat return err } - writer.Write(sync.NetworkSubtree(hostname), &sync.NetworkConfig{ + writer.Write(ctx, sync.NetworkSubtree(hostname), &sync.NetworkConfig{ Network: "default", Enable: true, Default: sync.LinkShape{ @@ -168,13 +168,13 @@ func Setup(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Watcher, w id := node.ID() runenv.Message("I am %s with addrs: %v", id, node.Addrs()) - if seq, err = writer.Write(sync.PeerSubtree, host.InfoFromHost(node)); err != nil { + if seq, err = writer.Write(ctx, sync.PeerSubtree, host.InfoFromHost(node)); err != nil { return nil, nil, nil, seq, fmt.Errorf("failed to write peer subtree in sync service: %w", err) } peerCh := make(chan *peer.AddrInfo, 16) - cancelSub, err := watcher.Subscribe(sync.PeerSubtree, peerCh) - if err != nil { + sctx, cancelSub := context.WithCancel(ctx) + if err := watcher.Subscribe(sctx, sync.PeerSubtree, peerCh); err != nil { return nil, nil, nil, seq, err } defer cancelSub() @@ -183,15 +183,14 @@ func Setup(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Watcher, w peers := make([]peer.AddrInfo, 0, runenv.TestInstanceCount) // Grab list of other peers that are available for this run. for i := 0; i < runenv.TestInstanceCount; i++ { - select { - case ai := <-peerCh: - if ai.ID == id { - continue - } - peers = append(peers, *ai) - case <-ctx.Done(): + ai, ok := <-peerCh + if !ok { return nil, nil, nil, seq, fmt.Errorf("no new peers in %d seconds", opts.Timeout/time.Second) } + if ai.ID == id { + continue + } + peers = append(peers, *ai) } sort.Slice(peers, func(i, j int) bool { @@ -240,7 +239,7 @@ func Bootstrap(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Watche } }() // Announce ourself as a bootstrap node. - if _, err := writer.Write(BootstrapSubtree, host.InfoFromHost(dht.Host())); err != nil { + if _, err := writer.Write(ctx, BootstrapSubtree, host.InfoFromHost(dht.Host())); err != nil { return err } // NOTE: If we start restricting the network, don't restrict @@ -389,23 +388,24 @@ func Bootstrap(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Watche // get all bootstrap peers. func getBootstrappers(ctx context.Context, runenv *runtime.RunEnv, watcher *sync.Watcher, opts *SetupOpts) ([]peer.AddrInfo, error) { + // cancel the sub + ctx, cancel := context.WithCancel(ctx) + defer cancel() + peerCh := make(chan *peer.AddrInfo, opts.NBootstrap) - cancelSub, err := watcher.Subscribe(BootstrapSubtree, peerCh) - if err != nil { + if err := watcher.Subscribe(ctx, BootstrapSubtree, peerCh); err != nil { return nil, err } - defer cancelSub() // TODO: remove this if it becomes too much coordination effort. peers := make([]peer.AddrInfo, opts.NBootstrap) // Grab list of other peers that are available for this run. 
for i := 0; i < opts.NBootstrap; i++ { - select { - case ai := <-peerCh: - peers[i] = *ai - case <-ctx.Done(): - return nil, fmt.Errorf("timed out waiting for bootstrappers") + ai, ok := <-peerCh + if !ok { + return peers, fmt.Errorf("timed out waiting for bootstrappers") } + peers[i] = *ai } runenv.Message("got all bootstrappers: %d", len(peers)) return peers, nil @@ -468,7 +468,7 @@ func Sync( doneCh := watcher.Barrier(ctx, state, int64(runenv.TestInstanceCount)) // Signal we're in the same state. - _, err := writer.SignalEntry(state) + _, err := writer.SignalEntry(ctx, state) if err != nil { return err } diff --git a/plans/dht/test/find_peers.go b/plans/dht/test/find_peers.go index 4902720de..c4cd4cf5b 100644 --- a/plans/dht/test/find_peers.go +++ b/plans/dht/test/find_peers.go @@ -26,7 +26,7 @@ func FindPeers(runenv *runtime.RunEnv) error { ctx, cancel := context.WithTimeout(context.Background(), opts.Timeout) defer cancel() - watcher, writer := sync.MustWatcherWriter(runenv) + watcher, writer := sync.MustWatcherWriter(ctx, runenv) defer watcher.Close() defer writer.Close() diff --git a/plans/dht/test/find_providers.go b/plans/dht/test/find_providers.go index 6e4bfad5e..e9486908a 100644 --- a/plans/dht/test/find_providers.go +++ b/plans/dht/test/find_providers.go @@ -29,7 +29,7 @@ func FindProviders(runenv *runtime.RunEnv) error { ctx, cancel := context.WithTimeout(context.Background(), opts.Timeout) defer cancel() - watcher, writer := sync.MustWatcherWriter(runenv) + watcher, writer := sync.MustWatcherWriter(ctx, runenv) defer watcher.Close() defer writer.Close() diff --git a/plans/network/main.go b/plans/network/main.go index 278a06f28..727a163e7 100644 --- a/plans/network/main.go +++ b/plans/network/main.go @@ -30,7 +30,7 @@ func run(runenv *runtime.RunEnv) error { } runenv.Message("before sync.MustWatcherWriter") - watcher, writer := sync.MustWatcherWriter(runenv) + watcher, writer := sync.MustWatcherWriter(ctx, runenv) defer watcher.Close() defer writer.Close() @@ -68,7 +68,7 @@ func run(runenv *runtime.RunEnv) error { } runenv.Message("before writer config") - _, err = writer.Write(sync.NetworkSubtree(hostname), &config) + _, err = writer.Write(ctx, sync.NetworkSubtree(hostname), &config) if err != nil { return err } @@ -92,7 +92,7 @@ func run(runenv *runtime.RunEnv) error { // Get a sequence number runenv.Message("get a sequence number") - seq, err := writer.Write(&sync.Subtree{ + seq, err := writer.Write(ctx, &sync.Subtree{ GroupKey: "ip-allocation", PayloadType: reflect.TypeOf(""), KeyFunc: func(val interface{}) string { @@ -129,7 +129,7 @@ func run(runenv *runtime.RunEnv) error { } logging.S().Debug("before writing changed ip config to redis") - _, err = writer.Write(sync.NetworkSubtree(hostname), &config) + _, err = writer.Write(ctx, sync.NetworkSubtree(hostname), &config) if err != nil { return err } @@ -222,7 +222,7 @@ func run(runenv *runtime.RunEnv) error { state := sync.State("ping-pong-" + test) // Don't reconfigure the network until we're done with the first test. 
- writer.SignalEntry(state) + writer.SignalEntry(ctx, state) err = <-watcher.Barrier(ctx, state, int64(runenv.TestInstanceCount)) if err != nil { return err @@ -238,7 +238,7 @@ func run(runenv *runtime.RunEnv) error { config.State = "latency-reduced" logging.S().Debug("writing new config with latency reduced") - _, err = writer.Write(sync.NetworkSubtree(hostname), &config) + _, err = writer.Write(ctx, sync.NetworkSubtree(hostname), &config) if err != nil { return err } diff --git a/sdk/sync/common.go b/sdk/sync/common.go index 7d00f924d..b263b1d42 100644 --- a/sdk/sync/common.go +++ b/sdk/sync/common.go @@ -1,6 +1,7 @@ package sync import ( + "context" "encoding/json" "fmt" "net" @@ -28,7 +29,7 @@ const ( // // TODO: source redis URL from environment variables. The Redis host and port // will be wired in by Nomad/Swarm. -func redisClient(runenv *runtime.RunEnv) (client *redis.Client, err error) { +func redisClient(ctx context.Context, runenv *runtime.RunEnv) (client *redis.Client, err error) { var ( host = os.Getenv(EnvRedisHost) port = os.Getenv(EnvRedisPort) @@ -43,7 +44,7 @@ func redisClient(runenv *runtime.RunEnv) (client *redis.Client, err error) { // Fall back to attempting to use `host.docker.internal` which // is only available in macOS and Windows. for _, h := range []string{RedisHostname, HostHostname} { - if addrs, err := net.LookupHost(h); err == nil && len(addrs) > 0 { + if addrs, err := net.DefaultResolver.LookupHost(ctx, h); err == nil && len(addrs) > 0 { host = h break } @@ -69,12 +70,12 @@ func redisClient(runenv *runtime.RunEnv) (client *redis.Client, err error) { client = redis.NewClient(opts) // PING redis to make sure we're alive. - return client, client.Ping().Err() + return client, client.WithContext(ctx).Ping().Err() } // MustWatcherWriter proxies to WatcherWriter, panicking if an error occurs. -func MustWatcherWriter(runenv *runtime.RunEnv) (*Watcher, *Writer) { - watcher, writer, err := WatcherWriter(runenv) +func MustWatcherWriter(ctx context.Context, runenv *runtime.RunEnv) (*Watcher, *Writer) { + watcher, writer, err := WatcherWriter(ctx, runenv) if err != nil { panic(err) } @@ -83,13 +84,13 @@ func MustWatcherWriter(runenv *runtime.RunEnv) (*Watcher, *Writer) { // WatcherWriter creates a Watcher and a Writer object associated with this test // run's sync tree. -func WatcherWriter(runenv *runtime.RunEnv) (*Watcher, *Writer, error) { - watcher, err := NewWatcher(runenv) +func WatcherWriter(ctx context.Context, runenv *runtime.RunEnv) (*Watcher, *Writer, error) { + watcher, err := NewWatcher(ctx, runenv) if err != nil { return nil, nil, err } - writer, err := NewWriter(runenv) + writer, err := NewWriter(ctx, runenv) if err != nil { return nil, nil, err } diff --git a/sdk/sync/redis_test.go b/sdk/sync/redis_test.go index e1c338cb5..b5a04225b 100644 --- a/sdk/sync/redis_test.go +++ b/sdk/sync/redis_test.go @@ -29,7 +29,7 @@ func ensureRedis(t *testing.T) (close func()) { // Try to obtain a client; if this fails, we'll attempt to start a redis // instance. - client, err := redisClient(runenv) + client, err := redisClient(context.Background(), runenv) if err == nil { return func() {} } @@ -42,7 +42,7 @@ func ensureRedis(t *testing.T) (close func()) { time.Sleep(1 * time.Second) // Try to obtain a client again. 
- if client, err = redisClient(runenv); err != nil { + if client, err = redisClient(context.Background(), runenv); err != nil { t.Fatalf("failed to obtain redis client despite starting instance: %v", err) } defer client.Close() @@ -55,19 +55,22 @@ func ensureRedis(t *testing.T) (close func()) { } func TestWatcherWriter(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + close := ensureRedis(t) defer close() runenv := runtime.RandomRunEnv() - watcher, err := NewWatcher(runenv) + watcher, err := NewWatcher(ctx, runenv) if err != nil { t.Fatal(err) } defer watcher.Close() peersCh := make(chan *peer.AddrInfo, 16) - cancel, err := watcher.Subscribe(PeerSubtree, peersCh) + err = watcher.Subscribe(ctx, PeerSubtree, peersCh) if err != nil { t.Fatal(err) } @@ -77,7 +80,7 @@ func TestWatcherWriter(t *testing.T) { t.Fatal(err) } - writer, err := NewWriter(runenv) + writer, err := NewWriter(ctx, runenv) if err != nil { t.Fatal(err) } @@ -92,7 +95,7 @@ func TestWatcherWriter(t *testing.T) { t.Fatal(err) } - writer.Write(PeerSubtree, ai) + writer.Write(ctx, PeerSubtree, ai) if err != nil { t.Fatal(err) } @@ -107,23 +110,23 @@ func TestWatcherWriter(t *testing.T) { } func TestBarrier(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + close := ensureRedis(t) defer close() runenv := runtime.RandomRunEnv() - watcher, writer := MustWatcherWriter(runenv) + watcher, writer := MustWatcherWriter(ctx, runenv) defer watcher.Close() defer writer.Close() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - state := State("yoda") ch := watcher.Barrier(ctx, state, 10) for i := 1; i <= 10; i++ { - if curr, err := writer.SignalEntry(state); err != nil { + if curr, err := writer.SignalEntry(ctx, state); err != nil { t.Fatal(err) } else if curr != int64(i) { t.Fatalf("expected current count to be: %d; was: %d", i, curr) @@ -145,19 +148,21 @@ func TestWatchInexistentKeyThenWrite(t *testing.T) { subtree = randomTestSubtree() ) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + closeRedis := ensureRedis(t) defer closeRedis() - watcher, writer := MustWatcherWriter(runenv) + watcher, writer := MustWatcherWriter(ctx, runenv) defer watcher.Close() defer writer.Close() ch := make(chan *string, 128) - subCancel, err := watcher.Subscribe(subtree, ch) + err := watcher.Subscribe(ctx, subtree, ch) if err != nil { t.Fatal(err) } - defer subCancel() doneCh := make(chan struct{}) go func() { @@ -182,21 +187,23 @@ func TestWriteAllBeforeWatch(t *testing.T) { subtree = randomTestSubtree() ) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + closeRedis := ensureRedis(t) defer closeRedis() - watcher, writer := MustWatcherWriter(runenv) + watcher, writer := MustWatcherWriter(ctx, runenv) defer watcher.Close() defer writer.Close() produce(t, writer, subtree, values) ch := make(chan *string, 128) - subCancel, err := watcher.Subscribe(subtree, ch) + err := watcher.Subscribe(ctx, subtree, ch) if err != nil { t.Fatal(err) } - defer subCancel() doneCh := make(chan struct{}) go func() { @@ -218,17 +225,20 @@ func TestSequenceOnWrite(t *testing.T) { subtree = randomTestSubtree() ) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + closeRedis := ensureRedis(t) defer closeRedis() s := "a" for i := 1; i <= iterations; i++ { - w, err := NewWriter(runenv) + w, err := NewWriter(ctx, runenv) if err != nil { t.Fatal(err) } - seq, err := 
w.Write(subtree, &s) + seq, err := w.Write(ctx, subtree, &s) if err != nil { t.Fatal(err) } @@ -250,16 +260,21 @@ func TestCloseSubscription(t *testing.T) { subtree = randomTestSubtree() ) - watcher, writer := MustWatcherWriter(runenv) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, writer := MustWatcherWriter(ctx, runenv) + + sctx, scancel := context.WithCancel(ctx) ch := make(chan *string, 128) - cancel, err := watcher.Subscribe(subtree, ch) + err := watcher.Subscribe(sctx, subtree, ch) if err != nil { t.Fatal(err) } s := "foo" - if _, err := writer.Write(subtree, &s); err != nil { + if _, err := writer.Write(ctx, subtree, &s); err != nil { t.Fatal(err) } @@ -269,9 +284,7 @@ func TestCloseSubscription(t *testing.T) { } // cancel the subscription. - if err := cancel(); err != nil { - t.Fatal(err) - } + scancel() v, ok = <-ch if ok && *v != "" { @@ -317,7 +330,7 @@ func consumeUnordered(t *testing.T, ctx context.Context, ch chan *string, values func produce(t *testing.T, writer *Writer, subtree *Subtree, values []string) { for i, s := range values { - if seq, err := writer.Write(subtree, &s); err != nil { + if seq, err := writer.Write(context.Background(), subtree, &s); err != nil { t.Fatalf("failed while writing key to subtree: %s", err) } else if seq != int64(i)+1 { t.Fatalf("expected seq == i+1; seq: %d; i: %d", seq, i) diff --git a/sdk/sync/subscription.go b/sdk/sync/subscription.go index 922d715d2..68c27977c 100644 --- a/sdk/sync/subscription.go +++ b/sdk/sync/subscription.go @@ -1,9 +1,8 @@ package sync import ( - "fmt" + "context" "reflect" - "strconv" "github.com/go-redis/redis/v7" ) @@ -19,29 +18,13 @@ type subscription struct { outCh reflect.Value - // connCh stores the connection ID of the subscription's conn. Consuming - // from here should always return a value, either a connection ID or -1 if - // no connection was created (e.g. error situation). - connCh chan int64 - closeCh chan struct{} - doneCh chan struct{} - result error -} - -func (s *subscription) isClosed() bool { - select { - case <-s.closeCh: - return true - default: - return false - } + cancel context.CancelFunc } // process subscribes to a stream from position 0 performing an indefinite // blocking XREAD. The XREAD will be cancelled when the subscription is // cancelled. func (s *subscription) process() { - defer close(s.doneCh) defer s.outCh.Close() var ( @@ -52,27 +35,38 @@ func (s *subscription) process() { startSeq, err := s.client.XLen(key).Result() if err != nil { - s.connCh <- -1 - s.result = fmt.Errorf("failed to fetch current length of stream: %w", err) + s.w.re.SLogger().Errorf("failed to fetch current length of stream: %w", err) return } log := s.w.re.SLogger().With("subtree", s.subtree, "start_seq", startSeq) - // Get a connection and store its connection ID, so that stop() can unblock - // it upon closure. + // Get a connection and store its connection ID, so we can unblock it when canceling. conn := s.client.Conn() defer conn.Close() - id, err := conn.ClientID().Result() - if err != nil { - s.connCh <- -1 - s.result = fmt.Errorf("failed to get the current conn id: %w", err) - return - } + connID, err := conn.ClientID().Result() + done := make(chan struct{}) + closed := make(chan struct{}) + go func() { + defer close(closed) + select { + case <-s.client.Context().Done(): + // we need a _non_ canceled client for this to work. 
+ client := s.client.WithContext(context.Background()) + err := client.ClientUnblockWithError(connID).Err() + if err != nil { + log.Errorw("failed to kill connection", "error", err) + } + case <-done: + // no need to unblock anything. + } + }() - // store the conn ID in the channel. - s.connCh <- id + defer func() { + close(done) + <-closed + }() args := &redis.XReadArgs{ Streams: []string{key, "0"}, @@ -80,11 +74,11 @@ func (s *subscription) process() { } var last redis.XMessage - for !s.isClosed() { + for { streams, err := conn.XRead(args).Result() if err != nil && err != redis.Nil { - if !s.isClosed() { - s.result = fmt.Errorf("failed to XREAD from subtree stream: %w", err) + if s.client.Context().Err() == nil { + log.Errorf("failed to XREAD from subtree stream: %w", err) } return } @@ -111,28 +105,3 @@ func (s *subscription) process() { args.Streams[1] = last.ID } } - -// stop stops this subcription. -func (s *subscription) stop() error { - if s.isClosed() { - <-s.doneCh - return s.result - } - - close(s.closeCh) - - connID := <-s.connCh - - // We have a connection to close. - if connID != -1 { - // this subscription has a connection associated with it. - if err := s.client.ClientKillByFilter("id", strconv.Itoa(int(connID))).Err(); err != nil { - err := fmt.Errorf("failed to kill connection: %w", err) - s.w.re.Message("%s", err) - return err - } - } - - <-s.doneCh - return s.result -} diff --git a/sdk/sync/watch.go b/sdk/sync/watch.go index eec879ca2..dd2ca9d95 100644 --- a/sdk/sync/watch.go +++ b/sdk/sync/watch.go @@ -10,31 +10,28 @@ import ( "github.com/ipfs/testground/sdk/runtime" "github.com/go-redis/redis/v7" - "github.com/hashicorp/go-multierror" ) // Watcher exposes methods to watch subtrees within the sync tree of this test. type Watcher struct { - lk sync.RWMutex - re *runtime.RunEnv - client *redis.Client - root string - subtrees map[*Subtree]map[*subscription]struct{} + re *runtime.RunEnv + client *redis.Client + root string + subs sync.WaitGroup } // NewWatcher begins watching the subtree underneath this path. -func NewWatcher(runenv *runtime.RunEnv) (w *Watcher, err error) { - client, err := redisClient(runenv) +func NewWatcher(ctx context.Context, runenv *runtime.RunEnv) (w *Watcher, err error) { + client, err := redisClient(ctx, runenv) if err != nil { return nil, fmt.Errorf("during redisClient: %w", err) } prefix := basePrefix(runenv) w = &Watcher{ - re: runenv, - client: client, - root: prefix, - subtrees: make(map[*Subtree]map[*subscription]struct{}), + re: runenv, + client: client, + root: prefix, } return w, nil } @@ -47,57 +44,44 @@ func NewWatcher(runenv *runtime.RunEnv) (w *Watcher, err error) { // that point, the caller should consume the error (or nil value) from the // returned errCh. // -// The user can cancel the subscription by calling the returned cancelFn. The -// subscription will die if an internal error occurs, in which case the cancelFn -// should also be called. -func (w *Watcher) Subscribe(subtree *Subtree, ch interface{}) (func() error, error) { +// The user can cancel the subscription by calling the returned cancelFn or by +// canceling the passed context. The subscription will die if an internal error +// occurs. 
+func (w *Watcher) Subscribe(ctx context.Context, subtree *Subtree, ch interface{}) error { + if err := ctx.Err(); err != nil { + return err + } + if err := w.client.Context().Err(); err != nil { + return err + } + chV := reflect.ValueOf(ch) if k := chV.Kind(); k != reflect.Chan { - return nil, fmt.Errorf("value is not a channel: %T", ch) + return fmt.Errorf("value is not a channel: %T", ch) } if err := subtree.AssertType(chV.Type().Elem()); err != nil { chV.Close() - return nil, err - } - - w.lk.Lock() - - // Make sure we have a subtree mapping. - if _, ok := w.subtrees[subtree]; !ok { - w.subtrees[subtree] = make(map[*subscription]struct{}) + return err } root := w.root + ":" + subtree.GroupKey sub := &subscription{ w: w, subtree: subtree, - client: w.client, + client: w.client.WithContext(ctx), key: root, - connCh: make(chan int64, 1), - closeCh: make(chan struct{}), - doneCh: make(chan struct{}), outCh: chV, } - w.subtrees[subtree][sub] = struct{}{} - w.lk.Unlock() - // Start the subscription. - go sub.process() - - cancelFn := func() error { - w.lk.Lock() - defer w.lk.Unlock() - - delete(w.subtrees[subtree], sub) - if len(w.subtrees[subtree]) == 0 { - delete(w.subtrees, subtree) - } + w.subs.Add(1) + go func() { + defer w.subs.Done() + sub.process() + }() - return sub.stop() - } - return cancelFn, nil + return nil } // Barrier awaits until the specified amount of items are advertising to be in @@ -123,6 +107,7 @@ func (w *Watcher) Barrier(ctx context.Context, state State, required int64) <-ch err error ticker = time.NewTicker(250 * time.Millisecond) k = state.Key(w.root) + client = w.client.WithContext(ctx) ) defer ticker.Stop() @@ -130,7 +115,7 @@ func (w *Watcher) Barrier(ctx context.Context, state State, required int64) <-ch for last != required { select { case <-ticker.C: - last, err = w.client.Get(k).Int64() + last, err = client.Get(k).Int64() if err != nil && err != redis.Nil { err = fmt.Errorf("error occured in barrier: %w", err) resCh <- err @@ -154,16 +139,9 @@ func (w *Watcher) Barrier(ctx context.Context, state State, required int64) <-ch // Close closes this watcher. After calling this method, the watcher can't be // resused. +// +// Note: Concurrently closing the watcher while calling Subscribe may panic. func (w *Watcher) Close() error { - w.lk.Lock() - defer w.lk.Unlock() - - var result *multierror.Error - for _, st := range w.subtrees { - for sub := range st { - result = multierror.Append(result, sub.stop()) - } - } - w.subtrees = nil - return result.ErrorOrNil() + defer w.subs.Wait() + return w.client.Close() } diff --git a/sdk/sync/write.go b/sdk/sync/write.go index b27251d1f..1c038fcb9 100644 --- a/sdk/sync/write.go +++ b/sdk/sync/write.go @@ -1,6 +1,7 @@ package sync import ( + "context" "encoding/json" "reflect" "strings" @@ -33,7 +34,7 @@ type Writer struct { lk sync.RWMutex client *redis.Client re *runtime.RunEnv - doneCh chan struct{} + cancel context.CancelFunc // root is the namespace under which this test run writes. It is derived // from the RunEnv. @@ -46,47 +47,48 @@ type Writer struct { // NewWriter creates a new Writer for a specific test run, as defined by the // RunEnv. 
-func NewWriter(runenv *runtime.RunEnv) (w *Writer, err error) { - client, err := redisClient(runenv) +func NewWriter(ctx context.Context, runenv *runtime.RunEnv) (w *Writer, err error) { + client, err := redisClient(ctx, runenv) if err != nil { return nil, err } + exitCtx, cancel := context.WithCancel(context.Background()) w = &Writer{ client: client, re: runenv, root: basePrefix(runenv), - doneCh: make(chan struct{}), + cancel: cancel, keepAliveSet: make(map[string]struct{}), } // Start a background worker that keeps alive the keeys - go w.keepAliveWorker() + go w.keepAliveWorker(exitCtx) return w, nil } // keepAliveWorker runs a loop that extends the TTL in the keepAliveSet every // `KeepAlivePeriod`. It should be launched as a goroutine. -func (w *Writer) keepAliveWorker() { +func (w *Writer) keepAliveWorker(ctx context.Context) { for { select { case <-time.After(KeepAlivePeriod): - w.keepAlive() - case <-w.doneCh: + w.keepAlive(ctx) + case <-ctx.Done(): return } } } // keepAlive extends the TTL of all keys in the keepAliveSet. -func (w *Writer) keepAlive() { +func (w *Writer) keepAlive(ctx context.Context) { w.lk.RLock() defer w.lk.RUnlock() // TODO: do this in a transaction. We risk the loop overlapping with the // refresh period, and all kinds of races. We need to be adaptive here. for k, _ := range w.keepAliveSet { - if err := w.client.Expire(k, TTL).Err(); err != nil { + if err := w.client.WithContext(ctx).Expire(k, TTL).Err(); err != nil { panic(err) } } @@ -102,7 +104,7 @@ func (w *Writer) keepAlive() { // // Else, if all succeeds, it returns the ordinal sequence number of this entry // within the subtree (starting at 1). -func (w *Writer) Write(subtree *Subtree, payload interface{}) (seq int64, err error) { +func (w *Writer) Write(ctx context.Context, subtree *Subtree, payload interface{}) (seq int64, err error) { if err = subtree.AssertType(reflect.ValueOf(payload).Type()); err != nil { return -1, err } @@ -119,7 +121,7 @@ func (w *Writer) Write(subtree *Subtree, payload interface{}) (seq int64, err er // Perform a Redis transaction, adding the item to the stream and fetching // the XLEN of the stream. var xlen *redis.IntCmd - _, err = w.client.TxPipelined(func(pipe redis.Pipeliner) error { + _, err = w.client.WithContext(ctx).TxPipelined(func(pipe redis.Pipeliner) error { pipe.XAdd(&redis.XAddArgs{ Stream: key, ID: "*", @@ -160,14 +162,14 @@ func (w *Writer) Write(subtree *Subtree, payload interface{}) (seq int64, err er // SignalEntry signals entry into the specified state, and returns how many // instances are currently in this state, including the caller. -func (w *Writer) SignalEntry(s State) (current int64, err error) { +func (w *Writer) SignalEntry(ctx context.Context, s State) (current int64, err error) { log := w.re.SLogger() log.Debugw("signalling entry to state", "state", s) // Increment a counter on the state key. key := strings.Join([]string{w.root, "states", string(s)}, ":") - seq, err := w.client.Incr(key).Result() + seq, err := w.client.WithContext(ctx).Incr(key).Result() if err != nil { return -1, err } @@ -188,7 +190,7 @@ func (w *Writer) SignalEntry(s State) (current int64, err error) { // Close closes this Writer, and drops all owned keys immediately, erroring if // those deletions fail. func (w *Writer) Close() error { - close(w.doneCh) + w.cancel() w.lk.Lock() defer w.lk.Unlock()
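For reviewers, a minimal sketch of how a test plan drives the sync API after this change, with contexts threaded through MustWatcherWriter, Subscribe, SignalEntry, and Barrier. The sync/runtime identifiers and signatures are the ones shown in the diff above; the package name, plan function, timeout, state name, and the sdk import paths (github.com/ipfs/testground/sdk/runtime, github.com/ipfs/testground/sdk/sync) are illustrative assumptions, not part of this patch.

package exampleplan

import (
	"context"
	"time"

	"github.com/ipfs/testground/sdk/runtime"
	"github.com/ipfs/testground/sdk/sync"

	"github.com/libp2p/go-libp2p-core/peer"
)

// examplePlan is a sketch of a plan entrypoint using the context-aware API.
func examplePlan(runenv *runtime.RunEnv) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	// Watcher and Writer are now constructed against the run context; if the
	// run times out or is canceled, their Redis operations are canceled too.
	watcher, writer := sync.MustWatcherWriter(ctx, runenv)
	defer watcher.Close()
	defer writer.Close()

	// Subscribe no longer returns a cancel function; the subscription lives
	// exactly as long as the context passed to it.
	peerCh := make(chan *peer.AddrInfo, 16)
	sctx, scancel := context.WithCancel(ctx)
	defer scancel()
	if err := watcher.Subscribe(sctx, sync.PeerSubtree, peerCh); err != nil {
		return err
	}

	// State signals (and writes) now take the context as their first argument.
	state := sync.State("ready")
	if _, err := writer.SignalEntry(ctx, state); err != nil {
		return err
	}

	// Barrier keeps its existing signature and shares the same cancellation
	// path as everything else.
	return <-watcher.Barrier(ctx, state, int64(runenv.TestInstanceCount))
}

Canceling sctx (or the parent ctx) is now the only way to tear down a subscription; the cancel/close functions previously returned by Subscribe, and the subscription.stop() path, no longer exist.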