diff --git a/channel.go b/channel.go
index 0499368f..6a0115f3 100644
--- a/channel.go
+++ b/channel.go
@@ -70,6 +70,5 @@ func newChannel(owner *channelOwner, object interface{}) *channel {
owner: owner,
object: object,
}
- channel.initEventEmitter()
return channel
}
diff --git a/channel_owner.go b/channel_owner.go
index 6c611499..d858ac87 100644
--- a/channel_owner.go
+++ b/channel_owner.go
@@ -93,7 +93,6 @@ func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner
}
c.channel = newChannel(c, self)
c.eventToSubscriptionMapping = map[string]string{}
- c.initEventEmitter()
}
type rootChannelOwner struct {
diff --git a/connection.go b/connection.go
index db1c4676..ab7cf3af 100644
--- a/connection.go
+++ b/connection.go
@@ -28,8 +28,7 @@ type connection struct {
transport transport
apiZone sync.Map
objects map[string]*channelOwner
- lastID int
- lastIDLock sync.Mutex
+ lastID atomic.Uint32
rootObject *rootChannelOwner
callbacks sync.Map
afterClose func()
@@ -97,7 +96,7 @@ func (c *connection) Dispatch(msg *message) {
}
method := msg.Method
if msg.ID != 0 {
- cb, _ := c.callbacks.LoadAndDelete(msg.ID)
+ cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID))
if cb.(*protocolCallback).noReply {
return
}
@@ -226,10 +225,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa
return nil, errors.New("The object has been collected to prevent unbounded heap growth.")
}
- c.lastIDLock.Lock()
- c.lastID++
- id := c.lastID
- c.lastIDLock.Unlock()
+ id := c.lastID.Add(1)
cb, _ := c.callbacks.LoadOrStore(id, newProtocolCallback(noReply, c.abort))
var (
metadata = make(map[string]interface{}, 0)
@@ -356,7 +352,7 @@ func fromNullableChannel(v interface{}) interface{} {
}
type protocolCallback struct {
- Callback chan result
+ callback chan result
noReply bool
abort <-chan struct{}
}
@@ -367,8 +363,12 @@ func (pc *protocolCallback) SetResult(r result) {
}
select {
case <-pc.abort:
+ select {
+ case pc.callback <- r:
+ default:
+ }
return
- case pc.Callback <- r:
+ case pc.callback <- r:
}
}
@@ -377,10 +377,15 @@ func (pc *protocolCallback) GetResult() (interface{}, error) {
return nil, nil
}
select {
- case result := <-pc.Callback:
+ case result := <-pc.callback:
return result.Data, result.Error
case <-pc.abort:
- return nil, errors.New("Connection closed")
+ select {
+ case result := <-pc.callback:
+ return result.Data, result.Error
+ default:
+ return nil, errors.New("Connection closed")
+ }
}
}
@@ -392,7 +397,7 @@ func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback
}
}
return &protocolCallback{
- Callback: make(chan result),
+ callback: make(chan result, 1),
abort: abort,
}
}
diff --git a/event_emitter.go b/event_emitter.go
index 0c2a9f8f..7e534b93 100644
--- a/event_emitter.go
+++ b/event_emitter.go
@@ -4,6 +4,8 @@ import (
"math"
"reflect"
"sync"
+
+ "golang.org/x/exp/slices"
)
type EventEmitter interface {
@@ -15,44 +17,33 @@ type EventEmitter interface {
}
type (
- eventRegister struct {
- once []interface{}
- on []interface{}
- }
eventEmitter struct {
eventsMutex sync.Mutex
events map[string]*eventRegister
+ hasInit bool
+ }
+ eventRegister struct {
+ listeners []listener
+ }
+ listener struct {
+ handler interface{}
+ once bool
}
)
-func (e *eventEmitter) Emit(name string, payload ...interface{}) (handled bool) {
+func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
- if _, ok := e.events[name]; !ok {
- return
- }
-
- if len(e.events[name].once) > 0 || len(e.events[name].on) > 0 {
- handled = true
- }
+ e.init()
- payloadV := make([]reflect.Value, 0)
-
- for _, p := range payload {
- payloadV = append(payloadV, reflect.ValueOf(p))
- }
-
- callHandlers := func(handlers []interface{}) {
- for _, handler := range handlers {
- handlerV := reflect.ValueOf(handler)
- handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))])
- }
+ evt, ok := e.events[name]
+ if !ok {
+ return
}
- callHandlers(e.events[name].on)
- callHandlers(e.events[name].once)
+ hasListener = evt.count() > 0
- e.events[name].once = make([]interface{}, 0)
+ evt.callHandlers(payload...)
return
}
@@ -67,60 +58,88 @@ func (e *eventEmitter) On(name string, handler interface{}) {
func (e *eventEmitter) RemoveListener(name string, handler interface{}) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
+ e.init()
+
if _, ok := e.events[name]; !ok {
return
}
- handlerPtr := reflect.ValueOf(handler).Pointer()
+ e.events[name].removeHandler(handler)
+}
- onHandlers := []interface{}{}
- for idx := range e.events[name].on {
- eventPtr := reflect.ValueOf(e.events[name].on[idx]).Pointer()
- if eventPtr != handlerPtr {
- onHandlers = append(onHandlers, e.events[name].on[idx])
- }
- }
- e.events[name].on = onHandlers
+// ListenerCount count the listeners by name, count all if name is empty
+func (e *eventEmitter) ListenerCount(name string) int {
+ e.eventsMutex.Lock()
+ defer e.eventsMutex.Unlock()
+ e.init()
- onceHandlers := []interface{}{}
- for idx := range e.events[name].once {
- eventPtr := reflect.ValueOf(e.events[name].once[idx]).Pointer()
- if eventPtr != handlerPtr {
- onceHandlers = append(onceHandlers, e.events[name].once[idx])
+ if name != "" {
+ evt, ok := e.events[name]
+ if !ok {
+ return 0
}
+ return evt.count()
}
- e.events[name].once = onceHandlers
-}
-
-// ListenerCount count the listeners by name, count all if name is empty
-func (e *eventEmitter) ListenerCount(name string) int {
count := 0
- e.eventsMutex.Lock()
for key := range e.events {
- if name == "" || name == key {
- count += len(e.events[key].on) + len(e.events[key].once)
- }
+ count += e.events[key].count()
}
- e.eventsMutex.Unlock()
+
return count
}
func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
e.eventsMutex.Lock()
+ e.init()
+
if _, ok := e.events[name]; !ok {
e.events[name] = &eventRegister{
- on: make([]interface{}, 0),
- once: make([]interface{}, 0),
+ listeners: make([]listener, 0),
}
}
- if once {
- e.events[name].once = append(e.events[name].once, handler)
- } else {
- e.events[name].on = append(e.events[name].on, handler)
- }
+ e.events[name].addHandler(handler, once)
e.eventsMutex.Unlock()
}
-func (e *eventEmitter) initEventEmitter() {
- e.events = make(map[string]*eventRegister)
+func (e *eventEmitter) init() {
+ if !e.hasInit {
+ e.events = make(map[string]*eventRegister, 0)
+ e.hasInit = true
+ }
+}
+
+func (e *eventRegister) addHandler(handler interface{}, once bool) {
+ e.listeners = append(e.listeners, listener{handler: handler, once: once})
+}
+
+func (e *eventRegister) count() int {
+ return len(e.listeners)
+}
+
+func (e *eventRegister) removeHandler(handler interface{}) {
+ handlerPtr := reflect.ValueOf(handler).Pointer()
+
+ e.listeners = slices.DeleteFunc[[]listener](e.listeners, func(l listener) bool {
+ return reflect.ValueOf(l.handler).Pointer() == handlerPtr
+ })
+}
+
+func (e *eventRegister) callHandlers(payloads ...interface{}) {
+ payloadV := make([]reflect.Value, 0)
+
+ for _, p := range payloads {
+ payloadV = append(payloadV, reflect.ValueOf(p))
+ }
+
+ handle := func(l listener) {
+ handlerV := reflect.ValueOf(l.handler)
+ handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))])
+ }
+
+ for _, l := range e.listeners {
+ if l.once {
+ defer e.removeHandler(l.handler)
+ }
+ handle(l)
+ }
}
diff --git a/event_emitter_test.go b/event_emitter_test.go
index e939d648..6fad9290 100644
--- a/event_emitter_test.go
+++ b/event_emitter_test.go
@@ -14,7 +14,6 @@ const (
func TestEventEmitterListenerCount(t *testing.T) {
handler := &eventEmitter{}
- handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
myHandler := func(payload ...interface{}) {
wasCalled <- payload[0]
@@ -32,7 +31,6 @@ func TestEventEmitterListenerCount(t *testing.T) {
func TestEventEmitterOn(t *testing.T) {
handler := &eventEmitter{}
- handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
require.Nil(t, handler.events[testEventName])
handler.On(testEventName, func(payload ...interface{}) {
@@ -48,7 +46,6 @@ func TestEventEmitterOn(t *testing.T) {
func TestEventEmitterOnce(t *testing.T) {
handler := &eventEmitter{}
- handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
require.Nil(t, handler.events[testEventName])
handler.Once(testEventName, func(payload ...interface{}) {
@@ -64,7 +61,6 @@ func TestEventEmitterOnce(t *testing.T) {
func TestEventEmitterRemove(t *testing.T) {
handler := &eventEmitter{}
- handler.initEventEmitter()
wasCalled := make(chan interface{}, 1)
require.Nil(t, handler.events[testEventName])
myHandler := func(payload ...interface{}) {
@@ -84,14 +80,12 @@ func TestEventEmitterRemove(t *testing.T) {
func TestEventEmitterRemoveEmpty(t *testing.T) {
handler := &eventEmitter{}
- handler.initEventEmitter()
handler.RemoveListener(testEventName, func(...interface{}) {})
require.Equal(t, 0, handler.ListenerCount(testEventName))
}
func TestEventEmitterRemoveKeepExisting(t *testing.T) {
handler := &eventEmitter{}
- handler.initEventEmitter()
handler.On(testEventName, func(...interface{}) {})
handler.Once(testEventName, func(...interface{}) {})
handler.RemoveListener("abc123", func(...interface{}) {})
@@ -101,7 +95,6 @@ func TestEventEmitterRemoveKeepExisting(t *testing.T) {
func TestEventEmitterOnLessArgsAcceptingReceiver(t *testing.T) {
handler := &eventEmitter{}
- handler.initEventEmitter()
wasCalled := make(chan bool, 1)
require.Nil(t, handler.events[testEventName])
handler.Once(testEventName, func(ev ...interface{}) {
diff --git a/go.mod b/go.mod
index c5a8fdb2..c5c96c6b 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.0
go.uber.org/multierr v1.11.0
+ golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
)
require (
diff --git a/go.sum b/go.sum
index e374082d..2bdc5198 100644
--- a/go.sum
+++ b/go.sum
@@ -43,6 +43,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc h1:ao2WRsKSzW6KuUY9IWPwWahcHCgR0s52IfwutMfEbdM=
+golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
diff --git a/local_utils.go b/local_utils.go
index b3c06697..2a882bfe 100644
--- a/local_utils.go
+++ b/local_utils.go
@@ -138,7 +138,7 @@ func (l *localUtilsImpl) TraceDiscarded(stacksId string) error {
return err
}
-func (l *localUtilsImpl) AddStackToTracingNoReply(id int, stack []map[string]interface{}) {
+func (l *localUtilsImpl) AddStackToTracingNoReply(id uint32, stack []map[string]interface{}) {
l.channel.SendNoReply("addStackToTracingNoReply", map[string]interface{}{
"callData": map[string]interface{}{
"id": id,
diff --git a/tests/browser_context_test.go b/tests/browser_context_test.go
index 8e75de7f..aa70c244 100644
--- a/tests/browser_context_test.go
+++ b/tests/browser_context_test.go
@@ -520,3 +520,18 @@ func TestPageErrorEventShouldWork(t *testing.T) {
require.Equal(t, page, weberror.Page())
require.ErrorContains(t, weberror.Error(), "boom")
}
+
+func TestBrowserContextOnResponse(t *testing.T) {
+ BeforeEach(t)
+ defer AfterEach(t)
+ responseChan := make(chan playwright.Response, 1)
+ context.OnResponse(func(response playwright.Response) {
+ responseChan <- response
+ })
+ _, err := page.Goto(fmt.Sprintf("%s/title.html", server.PREFIX))
+ require.NoError(t, err)
+ response := <-responseChan
+ body, err := response.Body()
+ require.NoError(t, err)
+ require.Equal(t, "
Woof-Woof\n", string(body))
+}
diff --git a/tests/browser_test.go b/tests/browser_test.go
index 868e5af4..731cdf31 100644
--- a/tests/browser_test.go
+++ b/tests/browser_test.go
@@ -14,6 +14,12 @@ func TestBrowserIsConnected(t *testing.T) {
require.True(t, browser.IsConnected())
}
+func TestBrowserShouldReturnBrowserType(t *testing.T) {
+ BeforeEach(t)
+ defer AfterEach(t)
+ require.Equal(t, browserType, browser.BrowserType())
+}
+
func TestBrowserVersion(t *testing.T) {
BeforeEach(t)
defer AfterEach(t)
@@ -23,7 +29,11 @@ func TestBrowserVersion(t *testing.T) {
func TestBrowserNewContext(t *testing.T) {
BeforeEach(t)
defer AfterEach(t)
- require.Equal(t, 1, len(context.Pages()))
+ context2, err := browser.NewContext()
+ require.NoError(t, err)
+ require.Equal(t, 2, len(browser.Contexts()))
+ require.NoError(t, context2.Close())
+ require.Equal(t, 1, len(browser.Contexts()))
}
func TestBrowserNewContextWithExtraHTTPHeaders(t *testing.T) {
diff --git a/tests/browser_type_test.go b/tests/browser_type_test.go
index 80d427aa..214ee3bf 100644
--- a/tests/browser_type_test.go
+++ b/tests/browser_type_test.go
@@ -126,8 +126,8 @@ func TestBrowserTypeConnectShouldEmitDisconnectedEvent(t *testing.T) {
defer AfterEach(t)
remoteServer, err := newRemoteServer()
require.NoError(t, err)
- disconnected1 := newSyncSlice()
- disconnected2 := newSyncSlice()
+ disconnected1 := newSyncSlice[bool]()
+ disconnected2 := newSyncSlice[bool]()
browser1, err := browserType.Connect(remoteServer.url)
require.NoError(t, err)
require.NotNil(t, browser1)
diff --git a/tests/console_message_test.go b/tests/console_message_test.go
index 475c9bfd..bb26302d 100644
--- a/tests/console_message_test.go
+++ b/tests/console_message_test.go
@@ -85,28 +85,28 @@ func TestConsoleShouldWorkForDifferentConsoleAPICalls(t *testing.T) {
console.error('calling console.error');
console.log(Promise.resolve('should not wait until resolved!'));
}`)
- messages := ChanToSlice(messagesChan, 6).([]playwright.ConsoleMessage)
+ messages := ChanToSlice(messagesChan, 6)
require.NoError(t, err)
- require.Equal(t, []interface{}{
+ require.Equal(t, []string{
"timeEnd",
"trace",
"dir",
"warning",
"error",
"log",
- }, Map(messages, func(msg interface{}) interface{} {
- return msg.(playwright.ConsoleMessage).Type()
+ }, Map(messages, func(msg playwright.ConsoleMessage) string {
+ return msg.Type()
}))
require.Contains(t, messages[0].Text(), "calling console.time")
- require.Equal(t, []interface{}{
+ require.Equal(t, []string{
"calling console.trace",
"calling console.dir",
"calling console.warn",
"calling console.error",
"Promise",
- }, Map(messages[1:], func(msg interface{}) interface{} {
- return msg.(playwright.ConsoleMessage).Text()
+ }, Map(messages[1:], func(msg playwright.ConsoleMessage) string {
+ return msg.Text()
}))
}
diff --git a/tests/frame_locator_test.go b/tests/frame_locator_test.go
index f3020095..1a72c6a7 100644
--- a/tests/frame_locator_test.go
+++ b/tests/frame_locator_test.go
@@ -79,15 +79,32 @@ func routeAmbiguous(t *testing.T, page playwright.Page) {
}
func TestFrameLocatorFirst(t *testing.T) {
- BeforeEach(t)
- defer AfterEach(t)
- routeAmbiguous(t, page)
- _, err := page.Goto(server.EMPTY_PAGE)
- require.NoError(t, err)
- innerText, err := page.Locator("body").FrameLocator("iframe").First().GetByRole("button").InnerText()
- require.NoError(t, err)
- require.Equal(t, "Hello from iframe-1.html", innerText)
+ t.Run("basic", func(t *testing.T) {
+
+ BeforeEach(t)
+ defer AfterEach(t)
+ routeAmbiguous(t, page)
+ _, err := page.Goto(server.EMPTY_PAGE)
+ require.NoError(t, err)
+
+ innerText, err := page.Locator("body").FrameLocator("iframe").First().GetByRole("button").InnerText()
+ require.NoError(t, err)
+ require.Equal(t, "Hello from iframe-1.html", innerText)
+ })
+
+ t.Run("ambiguous", func(t *testing.T) {
+ BeforeEach(t)
+ defer AfterEach(t)
+ routeAmbiguous(t, page)
+ _, err := page.Goto(server.EMPTY_PAGE)
+ require.NoError(t, err)
+
+ innerText, err := page.Locator("body").FrameLocator("iframe").Nth(1).Locator("button").InnerText()
+ require.NoError(t, err)
+ require.Equal(t, "Hello from iframe-2.html", innerText)
+ })
+
}
func TestFrameLocatorNth(t *testing.T) {
diff --git a/tests/frame_test.go b/tests/frame_test.go
index c08dd15e..46c87a74 100644
--- a/tests/frame_test.go
+++ b/tests/frame_test.go
@@ -216,3 +216,18 @@ func TestFrameParent(t *testing.T) {
require.Equal(t, page.MainFrame(), frames[1].ParentFrame())
require.Equal(t, page.MainFrame(), frames[2].ParentFrame())
}
+
+func TestFrameShouldHandleNestedFrames(t *testing.T) {
+ BeforeEach(t)
+ defer AfterEach(t)
+ _, err := page.Goto(server.PREFIX + "/frames/nested-frames.html")
+ require.NoError(t, err)
+ dump := utils.DumpFrames(page.MainFrame(), "")
+ require.Equal(t, []string{
+ "http://localhost:/frames/nested-frames.html",
+ " http://localhost:/frames/frame.html (aframe)",
+ " http://localhost:/frames/two-frames.html (2frames)",
+ " http://localhost:/frames/frame.html (dos)",
+ " http://localhost:/frames/frame.html (uno)",
+ }, dump)
+}
diff --git a/tests/helper_test.go b/tests/helper_test.go
index 4e93b00c..40eba382 100644
--- a/tests/helper_test.go
+++ b/tests/helper_test.go
@@ -13,7 +13,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
- "reflect"
+ "sort"
"strings"
"sync"
"testing"
@@ -209,52 +209,43 @@ func (s *testServer) WaitForRequestChan(path string) <-chan *http.Request {
return channel
}
-func Map(vs interface{}, f func(interface{}) interface{}) []interface{} {
- v := reflect.ValueOf(vs)
- vsm := make([]interface{}, v.Len())
- for i := 0; i < v.Len(); i++ {
- vsm[i] = f(v.Index(i).Interface())
+func Map[T any, R any](vs []T, f func(T) R) []R {
+ vsm := make([]R, len(vs))
+ for i, v := range vs {
+ vsm[i] = f(v)
}
return vsm
}
-// ChanToSlice reads all data from ch (which must be a chan), returning a
-// slice of the data. If ch is a 'T chan' then the return value is of type
-// []T inside the returned interface.
-// A typical call would be sl := ChanToSlice(ch).([]int)
-func ChanToSlice(ch interface{}, amount int) interface{} {
- chv := reflect.ValueOf(ch)
- slv := reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(ch).Elem()), 0, 0)
+// ChanToSlice reads amount of values from the channel, returns them as a slice
+func ChanToSlice[T any](ch chan T, amount int) []T {
+ data := make([]T, 0)
for i := 0; i < amount; i++ {
- v, ok := chv.Recv()
- if !ok {
- return slv.Interface()
- }
- slv = reflect.Append(slv, v)
+ data = append(data, <-ch)
}
- return slv.Interface()
+ return data
}
-type syncSlice struct {
+type syncSlice[T any] struct {
sync.Mutex
- slice []interface{}
+ slice []T
}
-func (s *syncSlice) Append(v interface{}) {
+func (s *syncSlice[T]) Append(v T) {
s.Lock()
+ defer s.Unlock()
s.slice = append(s.slice, v)
- s.Unlock()
}
-func (s *syncSlice) Get() interface{} {
+func (s *syncSlice[T]) Get() []T {
s.Lock()
defer s.Unlock()
return s.slice
}
-func newSyncSlice() *syncSlice {
- return &syncSlice{
- slice: make([]interface{}, 0),
+func newSyncSlice[T any]() *syncSlice[T] {
+ return &syncSlice[T]{
+ slice: make([]T, 0),
}
}
@@ -288,6 +279,24 @@ func (t *testUtils) DetachFrame(page playwright.Page, frameId string) error {
return err
}
+func (tu *testUtils) DumpFrames(frame playwright.Frame, indentation string) []string {
+ desc := strings.Replace(frame.URL(), server.PREFIX, "http://localhost:", 1)
+ if frame.Name() != "" {
+ desc = fmt.Sprintf("%s (%s)", desc, frame.Name())
+ }
+ result := []string{
+ indentation + desc,
+ }
+ sortedFrames := frame.ChildFrames()
+ sort.SliceStable(sortedFrames, func(i, j int) bool {
+ return (sortedFrames[i].URL() + sortedFrames[i].Name()) < (sortedFrames[j].URL() + sortedFrames[j].Name())
+ })
+ for _, f := range sortedFrames {
+ result = append(result, tu.DumpFrames(f, " "+indentation)...)
+ }
+ return result
+}
+
func (tu *testUtils) VerifyViewport(t *testing.T, page playwright.Page, width, height int) {
require.Equal(t, page.ViewportSize().Width, width)
require.Equal(t, page.ViewportSize().Height, height)
diff --git a/tests/route_test.go b/tests/route_test.go
index 8c74697a..5b431175 100644
--- a/tests/route_test.go
+++ b/tests/route_test.go
@@ -174,7 +174,7 @@ func TestRouteFulfillPath(t *testing.T) {
func TestRequestFinished(t *testing.T) {
BeforeEach(t)
defer AfterEach(t)
- eventsStorage := newSyncSlice()
+ eventsStorage := newSyncSlice[string]()
var request playwright.Request
page.Once("request", func(r playwright.Request) {
request = r
@@ -187,7 +187,7 @@ func TestRequestFinished(t *testing.T) {
require.NoError(t, err)
require.NoError(t, response.Finished())
eventsStorage.Append("requestfinished")
- require.Equal(t, []interface{}{"request", "response", "requestfinished"}, eventsStorage.Get())
+ require.Equal(t, []string{"request", "response", "requestfinished"}, eventsStorage.Get())
require.Equal(t, response.Request(), request)
require.Equal(t, response.Frame(), page.MainFrame())
}
diff --git a/video.go b/video.go
index 85b5b128..a57b61ab 100644
--- a/video.go
+++ b/video.go
@@ -9,6 +9,7 @@ type videoImpl struct {
page *pageImpl
artifact *artifactImpl
artifactChan chan *artifactImpl
+ done chan struct{}
closeOnce sync.Once
isRemote bool
}
@@ -49,9 +50,7 @@ func (v *videoImpl) artifactReady(artifact *artifactImpl) {
func (v *videoImpl) pageClosed(p Page) {
v.closeOnce.Do(func() {
- if v.artifactChan != nil {
- close(v.artifactChan)
- }
+ close(v.done)
})
}
@@ -65,19 +64,31 @@ func (v *videoImpl) getArtifact() {
v.pageClosed(v.page)
}
}
- artifact := <-v.artifactChan
- if artifact != nil {
- v.artifact = artifact
+ select {
+ case artifact := <-v.artifactChan:
+ if artifact != nil {
+ v.artifact = artifact
+ }
+ case <-v.done: // page closed
+ select { // make sure get artifact if it's ready before page closed
+ case artifact := <-v.artifactChan:
+ if artifact != nil {
+ v.artifact = artifact
+ }
+ default:
+ }
}
}
func newVideo(page *pageImpl) *videoImpl {
video := &videoImpl{
- page: page,
- isRemote: page.connection.isRemote,
+ page: page,
+ artifactChan: make(chan *artifactImpl, 1),
+ done: make(chan struct{}, 1),
+ isRemote: page.connection.isRemote,
}
- video.artifactChan = make(chan *artifactImpl, 1)
- if page.IsClosed() {
+
+ if page.isClosed {
video.pageClosed(page)
} else {
page.OnClose(video.pageClosed)
diff --git a/waiter_test.go b/waiter_test.go
index bf4b3c5b..0afed58d 100644
--- a/waiter_test.go
+++ b/waiter_test.go
@@ -17,7 +17,6 @@ const (
func TestWaiterWaitForEvent(t *testing.T) {
timeout := 500.0
emitter := &eventEmitter{}
- emitter.initEventEmitter()
waiter := newWaiter().WithTimeout(timeout)
_, err := waiter.Wait()
require.Error(t, err)
@@ -35,7 +34,6 @@ func TestWaiterWaitForEvent(t *testing.T) {
func TestWaiterWaitForEventWithPredicate(t *testing.T) {
timeout := 500.0
emitter := &eventEmitter{}
- emitter.initEventEmitter()
waiter := newWaiter().WithTimeout(timeout)
waiter.WaitForEvent(emitter, testEventNameFoobar, func(payload interface{}) bool {
content, ok := payload.(string)
@@ -57,7 +55,6 @@ func TestWaiterWaitForEventWithPredicate(t *testing.T) {
func TestWaiterRejectOnTimeout(t *testing.T) {
timeout := 300.0
emitter := &eventEmitter{}
- emitter.initEventEmitter()
waiter := newWaiter().WithTimeout(timeout)
waiter.WaitForEvent(emitter, testEventNameFoobar, nil)
go func() {
@@ -73,7 +70,6 @@ func TestWaiterRejectOnEvent(t *testing.T) {
errCause := fmt.Errorf("reject on event")
errPredicate := fmt.Errorf("payload on event")
emitter := &eventEmitter{}
- emitter.initEventEmitter()
waiter := newWaiter().RejectOnEvent(emitter, testEventNameReject, errCause)
waiter.RejectOnEvent(emitter, testEventNameFoobar, errPredicate, func(payload interface{}) bool {
content, ok := payload.(string)
@@ -99,7 +95,6 @@ func TestWaiterRejectOnEventWithPredicate(t *testing.T) {
errCause := fmt.Errorf("reject on event")
errPredicate := fmt.Errorf("payload on event")
emitter := &eventEmitter{}
- emitter.initEventEmitter()
waiter := newWaiter().RejectOnEvent(emitter, testEventNameReject, errCause)
waiter.RejectOnEvent(emitter, testEventNameFoobar, errPredicate, func(payload interface{}) bool {
content, ok := payload.(string)
@@ -123,7 +118,6 @@ func TestWaiterRejectOnEventWithPredicate(t *testing.T) {
func TestWaiterReturnErrorWhenMisuse(t *testing.T) {
emitter := &eventEmitter{}
- emitter.initEventEmitter()
waiter := newWaiter()
waiter.WaitForEvent(emitter, testEventNameFoobar, nil)
waiter.WithTimeout(500)