Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ package osquery
import (
"context"
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"

"git.apache.org/thrift.git/lib/go/thrift"
Expand Down Expand Up @@ -52,6 +49,7 @@ type ExtensionManagerServer struct {
timeout time.Duration
pingInterval time.Duration // How often to ping osquery server
mutex sync.Mutex
started bool // Used to ensure tests wait until the server is actually started
}

// validRegistryNames contains the allowable RegistryName() values. If a plugin
Expand Down Expand Up @@ -167,6 +165,9 @@ func (s *ExtensionManagerServer) Start() error {

s.server = thrift.NewTSimpleServer2(processor, s.transport)
server = s.server

s.started = true

return nil
}()

Expand All @@ -177,23 +178,14 @@ func (s *ExtensionManagerServer) Start() error {
return server.Serve()
}

// Run starts the extension manager and runs until an an interrupt
// signal is received.
// Run will call Shutdown before exiting.
// Run starts the extension manager and runs until osquery calls for a shutdown
// or the osquery instance goes away.
func (s *ExtensionManagerServer) Run() error {
errc := make(chan error)
go func() {
errc <- s.Start()
}()

// Interrupt handler.
go func() {
sig := make(chan os.Signal)
signal.Notify(sig, os.Interrupt, os.Kill, syscall.SIGTERM)
<-sig
errc <- nil
}()

// Watch for the osquery process going away. If so, initiate shutdown.
go func() {
for {
Expand Down Expand Up @@ -265,5 +257,19 @@ func (s *ExtensionManagerServer) Shutdown() error {
server.Stop()
}()
}

return nil
}

// Useful for testing
func (s *ExtensionManagerServer) waitStarted() {
for {
s.mutex.Lock()
started := s.started
s.mutex.Unlock()
if started {
time.Sleep(10 * time.Millisecond)
break
}
}
}
44 changes: 41 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,17 @@ func testShutdownDeadlock(t *testing.T) {
},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}

wait := sync.WaitGroup{}

wait.Add(1)
go func() {
err := server.Start()
require.Nil(t, err)
wait.Done()
}()
// Sleep long enough for server to start listening on socket
time.Sleep(500 * time.Millisecond)
// Wait for server to be set up
server.waitStarted()

// Create a raw client to access the shutdown method that is not
// usually exposed.
Expand All @@ -127,7 +132,6 @@ func testShutdownDeadlock(t *testing.T) {

// Simultaneously call shutdown through a request from the client and
// directly on the server object.
wait := sync.WaitGroup{}
wait.Add(1)
go func() {
defer wait.Done()
Expand All @@ -148,6 +152,40 @@ func testShutdownDeadlock(t *testing.T) {
close(completed)
}()

// either indicate successful shutdown, or fatal the test because it
// hung
select {
case <-completed:
// Success. Do nothing.
case <-time.After(5 * time.Second):
t.Fatal("hung on shutdown")
}
}

func TestShutdownBasic(t *testing.T) {
tempPath, err := ioutil.TempFile("", "")
require.Nil(t, err)
defer os.Remove(tempPath.Name())

retUUID := osquery.ExtensionRouteUUID(0)
mock := &MockExtensionManager{
RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
},
}
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}

completed := make(chan struct{})
go func() {
err := server.Start()
require.NoError(t, err)
close(completed)
}()

server.waitStarted()
err = server.Shutdown()
require.NoError(t, err)

// Either indicate successful shutdown, or fatal the test because it
// hung
select {
Expand Down