Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions handlers/useSession.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (
)

type UseSession struct {
Pool *pool.Pool
Pool *pool.Pool
Cache *pool.Cache
}

func (h *UseSession) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
Expand All @@ -24,12 +25,17 @@ func (h *UseSession) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
return
}
sessionID := re.FindStringSubmatch(r.URL.Path)[1]
targetNode, err := h.Pool.GetNodeBySessionID(sessionID)
if err != nil {
errorMessage := "session " + sessionID + " not found in node pool: " + err.Error()
log.Infof(errorMessage)
http.Error(rw, errorMessage, http.StatusNotFound)
return
targetNode, ok := h.Cache.Get(sessionID)
var err error
if !ok {
targetNode, err = h.Pool.GetNodeBySessionID(sessionID)
if err != nil {
errorMessage := "session " + sessionID + " not found in node pool: " + err.Error()
log.Infof(errorMessage)
http.Error(rw, errorMessage, http.StatusNotFound)
return
}
h.Cache.Set(sessionID, targetNode)
}

proxy := httputil.NewSingleHostReverseProxy(&url.URL{
Expand Down
11 changes: 10 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,22 @@ func main() {
}
}()

cache := pool.NewCache(time.Minute * 10) // todo: move to config

go func() {
for {
cache.CleanUp()
time.Sleep(time.Minute) // todo: move to config
}
}()

m := middleware.NewLogMiddleware(statsdClient)
http.Handle("/wd/hub/session", m.Log(&handlers.CreateSession{Pool: poolInstance, ClientFactory: clientFactory})) //selenium
http.Handle("/session", m.Log(&handlers.CreateSession{Pool: poolInstance, ClientFactory: clientFactory})) //wda
http.Handle("/grid/register", m.Log(&handlers.RegisterNode{Pool: poolInstance}))
http.Handle("/grid/api/proxy", &handlers.APIProxy{Pool: poolInstance})
http.HandleFunc("/_info", heartbeat)
http.Handle("/", m.Log(&handlers.UseSession{Pool: poolInstance}))
http.Handle("/", m.Log(&handlers.UseSession{Pool: poolInstance, Cache: cache}))

server := &http.Server{Addr: fmt.Sprintf(":%v", cfg.Grid.Port)}
serverError := make(chan error)
Expand Down
55 changes: 55 additions & 0 deletions pool/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package pool

import (
"sync"
"time"
)

// Cache no thread safe
type Cache struct {
storage map[string]*cacheEntry
expirationTime time.Duration
sync.RWMutex
}

type cacheEntry struct {
node *Node
created time.Time
}

func NewCache(expirationTime time.Duration) *Cache {
return &Cache{
storage: make(map[string]*cacheEntry),
expirationTime: expirationTime,
}
}

func (c *Cache) Set(key string, node *Node) {
c.Lock()
c.storage[key] = &cacheEntry{
node: node,
created: time.Now(),
}
c.Unlock()
}

func (c *Cache) Get(key string) (node *Node, ok bool) {
c.RLock()
entry, ok := c.storage[key]
if !ok {
c.RUnlock()
return nil, false
}
c.RUnlock()
return entry.node, true
}

func (c *Cache) CleanUp() {
c.Lock()
for i, _ := range c.storage {
if time.Since(c.storage[i].created) > c.expirationTime {
delete(c.storage, i)
}
}
c.Unlock()
}
49 changes: 49 additions & 0 deletions pool/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package pool

import (
"testing"
"github.com/stretchr/testify/assert"
"time"
)

func TestNewCache(t *testing.T) {
c := NewCache(time.Second)
assert.NotNil(t, c)
}

func TestCache_Set_ReturnsNode(t *testing.T) {
c := NewCache(time.Second)
key := "1"
node := new(Node)
c.Set(key, node)
assert.Equal(t, node, c.storage[key].node)
}

func TestCache_Get_NodeExists_ReturnsNodeTrue(t *testing.T) {
c := NewCache(time.Second)
key := "1"
nodeExp := new(Node)
c.storage[key] = &cacheEntry{node: nodeExp, created: time.Now()}
node, ok := c.Get(key)
assert.Equal(t, nodeExp, node)
assert.True(t, ok)
}

func TestCache_Get_NodeNotExists_ReturnsNilFalse(t *testing.T) {
c := NewCache(time.Second)
key := "1"
node, ok := c.Get(key)
assert.Nil(t, node)
assert.False(t, ok)
}

func TestCache_CleanUp_ExpiredPart_RemoveExpired(t *testing.T) {
c := NewCache(time.Minute)
nodeExp := new(Node)
c.storage["1"] = &cacheEntry{node: nodeExp, created: time.Now().Add(-time.Hour)}
c.storage["2"] = &cacheEntry{node: nodeExp, created: time.Now().Add(-time.Hour)}
c.storage["3"] = &cacheEntry{node: nodeExp, created: time.Now().Add(time.Hour)}
c.storage["4"] = &cacheEntry{node: nodeExp, created: time.Now().Add(time.Hour)}
c.CleanUp()
assert.Len(t, c.storage, 2)
}
4 changes: 3 additions & 1 deletion testing/webdriver-node-mock/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"time"
)

var constResponse = RandStringRunes(10000)

// status return current status
func status(rw http.ResponseWriter, r *http.Request) {
sessions := &jsonwire.Message{}
Expand Down Expand Up @@ -108,7 +110,7 @@ func useSession(rw http.ResponseWriter, r *http.Request) {
case parsedUrl[2] == "" && r.Method == http.MethodDelete: // session closed by client
currentSessionID = ""
default:
responseMessage.Value = RandStringRunes(10000)
responseMessage.Value = constResponse
}
err := json.NewEncoder(rw).Encode(responseMessage)
if err != nil {
Expand Down