diff --git a/channel.go b/channel.go index 0499368f..6a0115f3 100644 --- a/channel.go +++ b/channel.go @@ -70,6 +70,5 @@ func newChannel(owner *channelOwner, object interface{}) *channel { owner: owner, object: object, } - channel.initEventEmitter() return channel } diff --git a/channel_owner.go b/channel_owner.go index 6c611499..d858ac87 100644 --- a/channel_owner.go +++ b/channel_owner.go @@ -93,7 +93,6 @@ func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner } c.channel = newChannel(c, self) c.eventToSubscriptionMapping = map[string]string{} - c.initEventEmitter() } type rootChannelOwner struct { diff --git a/connection.go b/connection.go index db1c4676..ab7cf3af 100644 --- a/connection.go +++ b/connection.go @@ -28,8 +28,7 @@ type connection struct { transport transport apiZone sync.Map objects map[string]*channelOwner - lastID int - lastIDLock sync.Mutex + lastID atomic.Uint32 rootObject *rootChannelOwner callbacks sync.Map afterClose func() @@ -97,7 +96,7 @@ func (c *connection) Dispatch(msg *message) { } method := msg.Method if msg.ID != 0 { - cb, _ := c.callbacks.LoadAndDelete(msg.ID) + cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID)) if cb.(*protocolCallback).noReply { return } @@ -226,10 +225,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa return nil, errors.New("The object has been collected to prevent unbounded heap growth.") } - c.lastIDLock.Lock() - c.lastID++ - id := c.lastID - c.lastIDLock.Unlock() + id := c.lastID.Add(1) cb, _ := c.callbacks.LoadOrStore(id, newProtocolCallback(noReply, c.abort)) var ( metadata = make(map[string]interface{}, 0) @@ -356,7 +352,7 @@ func fromNullableChannel(v interface{}) interface{} { } type protocolCallback struct { - Callback chan result + callback chan result noReply bool abort <-chan struct{} } @@ -367,8 +363,12 @@ func (pc *protocolCallback) SetResult(r result) { } select { case <-pc.abort: + select { + case pc.callback <- r: + default: + } return - case pc.Callback <- r: + case pc.callback <- r: } } @@ -377,10 +377,15 @@ func (pc *protocolCallback) GetResult() (interface{}, error) { return nil, nil } select { - case result := <-pc.Callback: + case result := <-pc.callback: return result.Data, result.Error case <-pc.abort: - return nil, errors.New("Connection closed") + select { + case result := <-pc.callback: + return result.Data, result.Error + default: + return nil, errors.New("Connection closed") + } } } @@ -392,7 +397,7 @@ func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback } } return &protocolCallback{ - Callback: make(chan result), + callback: make(chan result, 1), abort: abort, } } diff --git a/event_emitter.go b/event_emitter.go index 0c2a9f8f..7e534b93 100644 --- a/event_emitter.go +++ b/event_emitter.go @@ -4,6 +4,8 @@ import ( "math" "reflect" "sync" + + "golang.org/x/exp/slices" ) type EventEmitter interface { @@ -15,44 +17,33 @@ type EventEmitter interface { } type ( - eventRegister struct { - once []interface{} - on []interface{} - } eventEmitter struct { eventsMutex sync.Mutex events map[string]*eventRegister + hasInit bool + } + eventRegister struct { + listeners []listener + } + listener struct { + handler interface{} + once bool } ) -func (e *eventEmitter) Emit(name string, payload ...interface{}) (handled bool) { +func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) { e.eventsMutex.Lock() defer e.eventsMutex.Unlock() - if _, ok := e.events[name]; !ok { - return - } - - if len(e.events[name].once) > 0 || len(e.events[name].on) > 0 { - handled = true - } + e.init() - payloadV := make([]reflect.Value, 0) - - for _, p := range payload { - payloadV = append(payloadV, reflect.ValueOf(p)) - } - - callHandlers := func(handlers []interface{}) { - for _, handler := range handlers { - handlerV := reflect.ValueOf(handler) - handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))]) - } + evt, ok := e.events[name] + if !ok { + return } - callHandlers(e.events[name].on) - callHandlers(e.events[name].once) + hasListener = evt.count() > 0 - e.events[name].once = make([]interface{}, 0) + evt.callHandlers(payload...) return } @@ -67,60 +58,88 @@ func (e *eventEmitter) On(name string, handler interface{}) { func (e *eventEmitter) RemoveListener(name string, handler interface{}) { e.eventsMutex.Lock() defer e.eventsMutex.Unlock() + e.init() + if _, ok := e.events[name]; !ok { return } - handlerPtr := reflect.ValueOf(handler).Pointer() + e.events[name].removeHandler(handler) +} - onHandlers := []interface{}{} - for idx := range e.events[name].on { - eventPtr := reflect.ValueOf(e.events[name].on[idx]).Pointer() - if eventPtr != handlerPtr { - onHandlers = append(onHandlers, e.events[name].on[idx]) - } - } - e.events[name].on = onHandlers +// ListenerCount count the listeners by name, count all if name is empty +func (e *eventEmitter) ListenerCount(name string) int { + e.eventsMutex.Lock() + defer e.eventsMutex.Unlock() + e.init() - onceHandlers := []interface{}{} - for idx := range e.events[name].once { - eventPtr := reflect.ValueOf(e.events[name].once[idx]).Pointer() - if eventPtr != handlerPtr { - onceHandlers = append(onceHandlers, e.events[name].once[idx]) + if name != "" { + evt, ok := e.events[name] + if !ok { + return 0 } + return evt.count() } - e.events[name].once = onceHandlers -} - -// ListenerCount count the listeners by name, count all if name is empty -func (e *eventEmitter) ListenerCount(name string) int { count := 0 - e.eventsMutex.Lock() for key := range e.events { - if name == "" || name == key { - count += len(e.events[key].on) + len(e.events[key].once) - } + count += e.events[key].count() } - e.eventsMutex.Unlock() + return count } func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) { e.eventsMutex.Lock() + e.init() + if _, ok := e.events[name]; !ok { e.events[name] = &eventRegister{ - on: make([]interface{}, 0), - once: make([]interface{}, 0), + listeners: make([]listener, 0), } } - if once { - e.events[name].once = append(e.events[name].once, handler) - } else { - e.events[name].on = append(e.events[name].on, handler) - } + e.events[name].addHandler(handler, once) e.eventsMutex.Unlock() } -func (e *eventEmitter) initEventEmitter() { - e.events = make(map[string]*eventRegister) +func (e *eventEmitter) init() { + if !e.hasInit { + e.events = make(map[string]*eventRegister, 0) + e.hasInit = true + } +} + +func (e *eventRegister) addHandler(handler interface{}, once bool) { + e.listeners = append(e.listeners, listener{handler: handler, once: once}) +} + +func (e *eventRegister) count() int { + return len(e.listeners) +} + +func (e *eventRegister) removeHandler(handler interface{}) { + handlerPtr := reflect.ValueOf(handler).Pointer() + + e.listeners = slices.DeleteFunc[[]listener](e.listeners, func(l listener) bool { + return reflect.ValueOf(l.handler).Pointer() == handlerPtr + }) +} + +func (e *eventRegister) callHandlers(payloads ...interface{}) { + payloadV := make([]reflect.Value, 0) + + for _, p := range payloads { + payloadV = append(payloadV, reflect.ValueOf(p)) + } + + handle := func(l listener) { + handlerV := reflect.ValueOf(l.handler) + handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))]) + } + + for _, l := range e.listeners { + if l.once { + defer e.removeHandler(l.handler) + } + handle(l) + } } diff --git a/event_emitter_test.go b/event_emitter_test.go index e939d648..6fad9290 100644 --- a/event_emitter_test.go +++ b/event_emitter_test.go @@ -14,7 +14,6 @@ const ( func TestEventEmitterListenerCount(t *testing.T) { handler := &eventEmitter{} - handler.initEventEmitter() wasCalled := make(chan interface{}, 1) myHandler := func(payload ...interface{}) { wasCalled <- payload[0] @@ -32,7 +31,6 @@ func TestEventEmitterListenerCount(t *testing.T) { func TestEventEmitterOn(t *testing.T) { handler := &eventEmitter{} - handler.initEventEmitter() wasCalled := make(chan interface{}, 1) require.Nil(t, handler.events[testEventName]) handler.On(testEventName, func(payload ...interface{}) { @@ -48,7 +46,6 @@ func TestEventEmitterOn(t *testing.T) { func TestEventEmitterOnce(t *testing.T) { handler := &eventEmitter{} - handler.initEventEmitter() wasCalled := make(chan interface{}, 1) require.Nil(t, handler.events[testEventName]) handler.Once(testEventName, func(payload ...interface{}) { @@ -64,7 +61,6 @@ func TestEventEmitterOnce(t *testing.T) { func TestEventEmitterRemove(t *testing.T) { handler := &eventEmitter{} - handler.initEventEmitter() wasCalled := make(chan interface{}, 1) require.Nil(t, handler.events[testEventName]) myHandler := func(payload ...interface{}) { @@ -84,14 +80,12 @@ func TestEventEmitterRemove(t *testing.T) { func TestEventEmitterRemoveEmpty(t *testing.T) { handler := &eventEmitter{} - handler.initEventEmitter() handler.RemoveListener(testEventName, func(...interface{}) {}) require.Equal(t, 0, handler.ListenerCount(testEventName)) } func TestEventEmitterRemoveKeepExisting(t *testing.T) { handler := &eventEmitter{} - handler.initEventEmitter() handler.On(testEventName, func(...interface{}) {}) handler.Once(testEventName, func(...interface{}) {}) handler.RemoveListener("abc123", func(...interface{}) {}) @@ -101,7 +95,6 @@ func TestEventEmitterRemoveKeepExisting(t *testing.T) { func TestEventEmitterOnLessArgsAcceptingReceiver(t *testing.T) { handler := &eventEmitter{} - handler.initEventEmitter() wasCalled := make(chan bool, 1) require.Nil(t, handler.events[testEventName]) handler.Once(testEventName, func(ev ...interface{}) { diff --git a/go.mod b/go.mod index c5a8fdb2..c5c96c6b 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.0 go.uber.org/multierr v1.11.0 + golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc ) require ( diff --git a/go.sum b/go.sum index e374082d..2bdc5198 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc h1:ao2WRsKSzW6KuUY9IWPwWahcHCgR0s52IfwutMfEbdM= +golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= diff --git a/local_utils.go b/local_utils.go index b3c06697..2a882bfe 100644 --- a/local_utils.go +++ b/local_utils.go @@ -138,7 +138,7 @@ func (l *localUtilsImpl) TraceDiscarded(stacksId string) error { return err } -func (l *localUtilsImpl) AddStackToTracingNoReply(id int, stack []map[string]interface{}) { +func (l *localUtilsImpl) AddStackToTracingNoReply(id uint32, stack []map[string]interface{}) { l.channel.SendNoReply("addStackToTracingNoReply", map[string]interface{}{ "callData": map[string]interface{}{ "id": id, diff --git a/tests/browser_context_test.go b/tests/browser_context_test.go index 8e75de7f..aa70c244 100644 --- a/tests/browser_context_test.go +++ b/tests/browser_context_test.go @@ -520,3 +520,18 @@ func TestPageErrorEventShouldWork(t *testing.T) { require.Equal(t, page, weberror.Page()) require.ErrorContains(t, weberror.Error(), "boom") } + +func TestBrowserContextOnResponse(t *testing.T) { + BeforeEach(t) + defer AfterEach(t) + responseChan := make(chan playwright.Response, 1) + context.OnResponse(func(response playwright.Response) { + responseChan <- response + }) + _, err := page.Goto(fmt.Sprintf("%s/title.html", server.PREFIX)) + require.NoError(t, err) + response := <-responseChan + body, err := response.Body() + require.NoError(t, err) + require.Equal(t, "Woof-Woof\n", string(body)) +} diff --git a/tests/browser_test.go b/tests/browser_test.go index 868e5af4..731cdf31 100644 --- a/tests/browser_test.go +++ b/tests/browser_test.go @@ -14,6 +14,12 @@ func TestBrowserIsConnected(t *testing.T) { require.True(t, browser.IsConnected()) } +func TestBrowserShouldReturnBrowserType(t *testing.T) { + BeforeEach(t) + defer AfterEach(t) + require.Equal(t, browserType, browser.BrowserType()) +} + func TestBrowserVersion(t *testing.T) { BeforeEach(t) defer AfterEach(t) @@ -23,7 +29,11 @@ func TestBrowserVersion(t *testing.T) { func TestBrowserNewContext(t *testing.T) { BeforeEach(t) defer AfterEach(t) - require.Equal(t, 1, len(context.Pages())) + context2, err := browser.NewContext() + require.NoError(t, err) + require.Equal(t, 2, len(browser.Contexts())) + require.NoError(t, context2.Close()) + require.Equal(t, 1, len(browser.Contexts())) } func TestBrowserNewContextWithExtraHTTPHeaders(t *testing.T) { diff --git a/tests/browser_type_test.go b/tests/browser_type_test.go index 80d427aa..214ee3bf 100644 --- a/tests/browser_type_test.go +++ b/tests/browser_type_test.go @@ -126,8 +126,8 @@ func TestBrowserTypeConnectShouldEmitDisconnectedEvent(t *testing.T) { defer AfterEach(t) remoteServer, err := newRemoteServer() require.NoError(t, err) - disconnected1 := newSyncSlice() - disconnected2 := newSyncSlice() + disconnected1 := newSyncSlice[bool]() + disconnected2 := newSyncSlice[bool]() browser1, err := browserType.Connect(remoteServer.url) require.NoError(t, err) require.NotNil(t, browser1) diff --git a/tests/console_message_test.go b/tests/console_message_test.go index 475c9bfd..bb26302d 100644 --- a/tests/console_message_test.go +++ b/tests/console_message_test.go @@ -85,28 +85,28 @@ func TestConsoleShouldWorkForDifferentConsoleAPICalls(t *testing.T) { console.error('calling console.error'); console.log(Promise.resolve('should not wait until resolved!')); }`) - messages := ChanToSlice(messagesChan, 6).([]playwright.ConsoleMessage) + messages := ChanToSlice(messagesChan, 6) require.NoError(t, err) - require.Equal(t, []interface{}{ + require.Equal(t, []string{ "timeEnd", "trace", "dir", "warning", "error", "log", - }, Map(messages, func(msg interface{}) interface{} { - return msg.(playwright.ConsoleMessage).Type() + }, Map(messages, func(msg playwright.ConsoleMessage) string { + return msg.Type() })) require.Contains(t, messages[0].Text(), "calling console.time") - require.Equal(t, []interface{}{ + require.Equal(t, []string{ "calling console.trace", "calling console.dir", "calling console.warn", "calling console.error", "Promise", - }, Map(messages[1:], func(msg interface{}) interface{} { - return msg.(playwright.ConsoleMessage).Text() + }, Map(messages[1:], func(msg playwright.ConsoleMessage) string { + return msg.Text() })) } diff --git a/tests/frame_locator_test.go b/tests/frame_locator_test.go index f3020095..1a72c6a7 100644 --- a/tests/frame_locator_test.go +++ b/tests/frame_locator_test.go @@ -79,15 +79,32 @@ func routeAmbiguous(t *testing.T, page playwright.Page) { } func TestFrameLocatorFirst(t *testing.T) { - BeforeEach(t) - defer AfterEach(t) - routeAmbiguous(t, page) - _, err := page.Goto(server.EMPTY_PAGE) - require.NoError(t, err) - innerText, err := page.Locator("body").FrameLocator("iframe").First().GetByRole("button").InnerText() - require.NoError(t, err) - require.Equal(t, "Hello from iframe-1.html", innerText) + t.Run("basic", func(t *testing.T) { + + BeforeEach(t) + defer AfterEach(t) + routeAmbiguous(t, page) + _, err := page.Goto(server.EMPTY_PAGE) + require.NoError(t, err) + + innerText, err := page.Locator("body").FrameLocator("iframe").First().GetByRole("button").InnerText() + require.NoError(t, err) + require.Equal(t, "Hello from iframe-1.html", innerText) + }) + + t.Run("ambiguous", func(t *testing.T) { + BeforeEach(t) + defer AfterEach(t) + routeAmbiguous(t, page) + _, err := page.Goto(server.EMPTY_PAGE) + require.NoError(t, err) + + innerText, err := page.Locator("body").FrameLocator("iframe").Nth(1).Locator("button").InnerText() + require.NoError(t, err) + require.Equal(t, "Hello from iframe-2.html", innerText) + }) + } func TestFrameLocatorNth(t *testing.T) { diff --git a/tests/frame_test.go b/tests/frame_test.go index c08dd15e..46c87a74 100644 --- a/tests/frame_test.go +++ b/tests/frame_test.go @@ -216,3 +216,18 @@ func TestFrameParent(t *testing.T) { require.Equal(t, page.MainFrame(), frames[1].ParentFrame()) require.Equal(t, page.MainFrame(), frames[2].ParentFrame()) } + +func TestFrameShouldHandleNestedFrames(t *testing.T) { + BeforeEach(t) + defer AfterEach(t) + _, err := page.Goto(server.PREFIX + "/frames/nested-frames.html") + require.NoError(t, err) + dump := utils.DumpFrames(page.MainFrame(), "") + require.Equal(t, []string{ + "http://localhost:/frames/nested-frames.html", + " http://localhost:/frames/frame.html (aframe)", + " http://localhost:/frames/two-frames.html (2frames)", + " http://localhost:/frames/frame.html (dos)", + " http://localhost:/frames/frame.html (uno)", + }, dump) +} diff --git a/tests/helper_test.go b/tests/helper_test.go index 4e93b00c..40eba382 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -13,7 +13,7 @@ import ( "net/http/httptest" "os" "path/filepath" - "reflect" + "sort" "strings" "sync" "testing" @@ -209,52 +209,43 @@ func (s *testServer) WaitForRequestChan(path string) <-chan *http.Request { return channel } -func Map(vs interface{}, f func(interface{}) interface{}) []interface{} { - v := reflect.ValueOf(vs) - vsm := make([]interface{}, v.Len()) - for i := 0; i < v.Len(); i++ { - vsm[i] = f(v.Index(i).Interface()) +func Map[T any, R any](vs []T, f func(T) R) []R { + vsm := make([]R, len(vs)) + for i, v := range vs { + vsm[i] = f(v) } return vsm } -// ChanToSlice reads all data from ch (which must be a chan), returning a -// slice of the data. If ch is a 'T chan' then the return value is of type -// []T inside the returned interface. -// A typical call would be sl := ChanToSlice(ch).([]int) -func ChanToSlice(ch interface{}, amount int) interface{} { - chv := reflect.ValueOf(ch) - slv := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(ch).Elem()), 0, 0) +// ChanToSlice reads amount of values from the channel, returns them as a slice +func ChanToSlice[T any](ch chan T, amount int) []T { + data := make([]T, 0) for i := 0; i < amount; i++ { - v, ok := chv.Recv() - if !ok { - return slv.Interface() - } - slv = reflect.Append(slv, v) + data = append(data, <-ch) } - return slv.Interface() + return data } -type syncSlice struct { +type syncSlice[T any] struct { sync.Mutex - slice []interface{} + slice []T } -func (s *syncSlice) Append(v interface{}) { +func (s *syncSlice[T]) Append(v T) { s.Lock() + defer s.Unlock() s.slice = append(s.slice, v) - s.Unlock() } -func (s *syncSlice) Get() interface{} { +func (s *syncSlice[T]) Get() []T { s.Lock() defer s.Unlock() return s.slice } -func newSyncSlice() *syncSlice { - return &syncSlice{ - slice: make([]interface{}, 0), +func newSyncSlice[T any]() *syncSlice[T] { + return &syncSlice[T]{ + slice: make([]T, 0), } } @@ -288,6 +279,24 @@ func (t *testUtils) DetachFrame(page playwright.Page, frameId string) error { return err } +func (tu *testUtils) DumpFrames(frame playwright.Frame, indentation string) []string { + desc := strings.Replace(frame.URL(), server.PREFIX, "http://localhost:", 1) + if frame.Name() != "" { + desc = fmt.Sprintf("%s (%s)", desc, frame.Name()) + } + result := []string{ + indentation + desc, + } + sortedFrames := frame.ChildFrames() + sort.SliceStable(sortedFrames, func(i, j int) bool { + return (sortedFrames[i].URL() + sortedFrames[i].Name()) < (sortedFrames[j].URL() + sortedFrames[j].Name()) + }) + for _, f := range sortedFrames { + result = append(result, tu.DumpFrames(f, " "+indentation)...) + } + return result +} + func (tu *testUtils) VerifyViewport(t *testing.T, page playwright.Page, width, height int) { require.Equal(t, page.ViewportSize().Width, width) require.Equal(t, page.ViewportSize().Height, height) diff --git a/tests/route_test.go b/tests/route_test.go index 8c74697a..5b431175 100644 --- a/tests/route_test.go +++ b/tests/route_test.go @@ -174,7 +174,7 @@ func TestRouteFulfillPath(t *testing.T) { func TestRequestFinished(t *testing.T) { BeforeEach(t) defer AfterEach(t) - eventsStorage := newSyncSlice() + eventsStorage := newSyncSlice[string]() var request playwright.Request page.Once("request", func(r playwright.Request) { request = r @@ -187,7 +187,7 @@ func TestRequestFinished(t *testing.T) { require.NoError(t, err) require.NoError(t, response.Finished()) eventsStorage.Append("requestfinished") - require.Equal(t, []interface{}{"request", "response", "requestfinished"}, eventsStorage.Get()) + require.Equal(t, []string{"request", "response", "requestfinished"}, eventsStorage.Get()) require.Equal(t, response.Request(), request) require.Equal(t, response.Frame(), page.MainFrame()) } diff --git a/video.go b/video.go index 85b5b128..a57b61ab 100644 --- a/video.go +++ b/video.go @@ -9,6 +9,7 @@ type videoImpl struct { page *pageImpl artifact *artifactImpl artifactChan chan *artifactImpl + done chan struct{} closeOnce sync.Once isRemote bool } @@ -49,9 +50,7 @@ func (v *videoImpl) artifactReady(artifact *artifactImpl) { func (v *videoImpl) pageClosed(p Page) { v.closeOnce.Do(func() { - if v.artifactChan != nil { - close(v.artifactChan) - } + close(v.done) }) } @@ -65,19 +64,31 @@ func (v *videoImpl) getArtifact() { v.pageClosed(v.page) } } - artifact := <-v.artifactChan - if artifact != nil { - v.artifact = artifact + select { + case artifact := <-v.artifactChan: + if artifact != nil { + v.artifact = artifact + } + case <-v.done: // page closed + select { // make sure get artifact if it's ready before page closed + case artifact := <-v.artifactChan: + if artifact != nil { + v.artifact = artifact + } + default: + } } } func newVideo(page *pageImpl) *videoImpl { video := &videoImpl{ - page: page, - isRemote: page.connection.isRemote, + page: page, + artifactChan: make(chan *artifactImpl, 1), + done: make(chan struct{}, 1), + isRemote: page.connection.isRemote, } - video.artifactChan = make(chan *artifactImpl, 1) - if page.IsClosed() { + + if page.isClosed { video.pageClosed(page) } else { page.OnClose(video.pageClosed) diff --git a/waiter_test.go b/waiter_test.go index bf4b3c5b..0afed58d 100644 --- a/waiter_test.go +++ b/waiter_test.go @@ -17,7 +17,6 @@ const ( func TestWaiterWaitForEvent(t *testing.T) { timeout := 500.0 emitter := &eventEmitter{} - emitter.initEventEmitter() waiter := newWaiter().WithTimeout(timeout) _, err := waiter.Wait() require.Error(t, err) @@ -35,7 +34,6 @@ func TestWaiterWaitForEvent(t *testing.T) { func TestWaiterWaitForEventWithPredicate(t *testing.T) { timeout := 500.0 emitter := &eventEmitter{} - emitter.initEventEmitter() waiter := newWaiter().WithTimeout(timeout) waiter.WaitForEvent(emitter, testEventNameFoobar, func(payload interface{}) bool { content, ok := payload.(string) @@ -57,7 +55,6 @@ func TestWaiterWaitForEventWithPredicate(t *testing.T) { func TestWaiterRejectOnTimeout(t *testing.T) { timeout := 300.0 emitter := &eventEmitter{} - emitter.initEventEmitter() waiter := newWaiter().WithTimeout(timeout) waiter.WaitForEvent(emitter, testEventNameFoobar, nil) go func() { @@ -73,7 +70,6 @@ func TestWaiterRejectOnEvent(t *testing.T) { errCause := fmt.Errorf("reject on event") errPredicate := fmt.Errorf("payload on event") emitter := &eventEmitter{} - emitter.initEventEmitter() waiter := newWaiter().RejectOnEvent(emitter, testEventNameReject, errCause) waiter.RejectOnEvent(emitter, testEventNameFoobar, errPredicate, func(payload interface{}) bool { content, ok := payload.(string) @@ -99,7 +95,6 @@ func TestWaiterRejectOnEventWithPredicate(t *testing.T) { errCause := fmt.Errorf("reject on event") errPredicate := fmt.Errorf("payload on event") emitter := &eventEmitter{} - emitter.initEventEmitter() waiter := newWaiter().RejectOnEvent(emitter, testEventNameReject, errCause) waiter.RejectOnEvent(emitter, testEventNameFoobar, errPredicate, func(payload interface{}) bool { content, ok := payload.(string) @@ -123,7 +118,6 @@ func TestWaiterRejectOnEventWithPredicate(t *testing.T) { func TestWaiterReturnErrorWhenMisuse(t *testing.T) { emitter := &eventEmitter{} - emitter.initEventEmitter() waiter := newWaiter() waiter.WaitForEvent(emitter, testEventNameFoobar, nil) waiter.WithTimeout(500)