From 105adf73741180e8c0678a68298ca0f4b736e758 Mon Sep 17 00:00:00 2001 From: Evan Wallace Date: Thu, 3 Nov 2022 00:13:27 -0400 Subject: [PATCH] attempted fix for #2485: `Add` and `Wait` safety --- cmd/esbuild/service.go | 78 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 10 deletions(-) diff --git a/cmd/esbuild/service.go b/cmd/esbuild/service.go index 65c0a637f35..834c8f7c474 100644 --- a/cmd/esbuild/service.go +++ b/cmd/esbuild/service.go @@ -47,6 +47,7 @@ type serviceType struct { keepAliveWaitGroup sync.WaitGroup mutex sync.Mutex nextRequestID uint32 + disableSendRequest bool } func (service *serviceType) getActiveBuild(key int) *activeBuild { @@ -108,10 +109,7 @@ func runService(sendPings bool) { // Write packets on a single goroutine so they aren't interleaved go func() { for { - packet, ok := <-service.outgoingPackets - if !ok { - break // No more packets - } + packet := <-service.outgoingPackets if _, err := os.Stdout.Write(packet.bytes); err != nil { os.Exit(1) // I/O error } @@ -122,6 +120,37 @@ func runService(sendPings bool) { // The protocol always starts with the version os.Stdout.Write(append(writeUint32(nil, uint32(len(esbuildVersion))), esbuildVersion...)) + // IMPORTANT: To avoid a data race, we must ensure that calling "Add()" on a + // "WaitGroup" with a counter of zero cannot possibly happen concurrently + // with calling "Wait()" on another goroutine. Thus we must ensure that the + // counter starts off positive before "Wait()" is called and that it only + // ever reaches zero exactly once (and that the counter never goes back above + // zero once it reaches zero). See this for discussion and more information: + // https://github.com/evanw/esbuild/issues/2485#issuecomment-1299318498 + service.keepAliveWaitGroup.Add(1) + defer func() { + // Stop all future calls to "sendRequest()" from calling "Add()" on our + // "WaitGroup" while we're calling "Wait()", since it may have a counter of + // zero at that point. 
This is a mutex so calls to "sendRequest()" must + // fall into one of two cases: + // + // a) The critical section in "sendRequest()" comes before this critical + // section and "Add()" is called while the counter is non-zero, which + // is fine. + // + // b) The critical section in "sendRequest()" comes after this critical + // section and it does not call "Add()", which is also fine. + // + service.mutex.Lock() + service.disableSendRequest = true + service.keepAliveWaitGroup.Done() + service.mutex.Unlock() + + // Wait for the last response to be written to stdout before returning from + // the enclosing function, which will return from "main()" and exit. + service.keepAliveWaitGroup.Wait() + }() + // Periodically ping the host even when we're idle. This will catch cases // where the host has disappeared and will never send us anything else but // we incorrectly think we are still needed. In that case we will now try @@ -173,11 +202,11 @@ func runService(sendPings bool) { // Move the remaining partial packet to the end to avoid reallocating stream = append(stream[:0], bytes...) } - - // Wait for the last response to be written to stdout - service.keepAliveWaitGroup.Wait() } +// This will either block until the request has been sent and a response has +// been received, or it will return nil to indicate failure to send due to +// stdin being closed. func (service *serviceType) sendRequest(request interface{}) interface{} { result := make(chan interface{}) var id uint32 @@ -193,7 +222,27 @@ func (service *serviceType) sendRequest(request interface{}) interface{} { service.callbacks[id] = callback return id }() + + // This function can be called from any thread. For example, it might be called + // by the implementation of watch or serve mode, or by esbuild's keep-alive + // timer goroutine. 
To avoid data races, we must ensure that it's not possible + // for "Add()" to be called on a "WaitGroup" with a counter of zero at the + // same time as "Wait()" is called on another goroutine. + // + // There's a potential data race when the stdin thread has finished reading + // from stdin (because stdin has been closed) and, while it's calling "Wait()" + // on the "WaitGroup" with a counter of zero, some other thread calls + // "sendRequest()", which calls "Add()". + // + // This data race is prevented by not sending any more requests once the + // stdin thread has finished. This is ok because we are about to exit. + service.mutex.Lock() + if service.disableSendRequest { + service.mutex.Unlock() + return nil + } service.keepAliveWaitGroup.Add(1) // The writer thread will call "Done()" + service.mutex.Unlock() service.outgoingPackets <- outgoingPacket{ bytes: encodePacket(packet{ id: id, @@ -769,10 +818,13 @@ func (service *serviceType) convertPlugins(key int, jsPlugins interface{}, activ build.OnStart(func() (api.OnStartResult, error) { result := api.OnStartResult{} - response := service.sendRequest(map[string]interface{}{ + response, ok := service.sendRequest(map[string]interface{}{ "command": "on-start", "key": key, }).(map[string]interface{}) + if !ok { + return result, errors.New("The service was stopped") + } if value, ok := response["errors"]; ok { result.Errors = decodeMessages(value.([]interface{})) @@ -798,7 +850,7 @@ func (service *serviceType) convertPlugins(key int, jsPlugins interface{}, activ return result, nil } - response := service.sendRequest(map[string]interface{}{ + response, ok := service.sendRequest(map[string]interface{}{ "command": "on-resolve", "key": key, "ids": ids, @@ -809,6 +861,9 @@ func (service *serviceType) convertPlugins(key int, jsPlugins interface{}, activ "kind": resolveKindToString(args.Kind), "pluginData": args.PluginData, }).(map[string]interface{}) + if !ok { + return result, errors.New("The service was stopped") + } if
value, ok := response["id"]; ok { id := value.(int) @@ -883,7 +938,7 @@ func (service *serviceType) convertPlugins(key int, jsPlugins interface{}, activ return result, nil } - response := service.sendRequest(map[string]interface{}{ + response, ok := service.sendRequest(map[string]interface{}{ "command": "on-load", "key": key, "ids": ids, @@ -892,6 +947,9 @@ func (service *serviceType) convertPlugins(key int, jsPlugins interface{}, activ "suffix": args.Suffix, "pluginData": args.PluginData, }).(map[string]interface{}) + if !ok { + return result, errors.New("The service was stopped") + } if value, ok := response["id"]; ok { id := value.(int)