diff --git a/server.go b/server.go index 949c4bf..a9c8773 100644 --- a/server.go +++ b/server.go @@ -3,10 +3,7 @@ package osquery import ( "context" "fmt" - "os" - "os/signal" "sync" - "syscall" "time" "git.apache.org/thrift.git/lib/go/thrift" @@ -52,6 +49,7 @@ type ExtensionManagerServer struct { timeout time.Duration pingInterval time.Duration // How often to ping osquery server mutex sync.Mutex + started bool // Used to ensure tests wait until the server is actually started } // validRegistryNames contains the allowable RegistryName() values. If a plugin @@ -167,6 +165,9 @@ func (s *ExtensionManagerServer) Start() error { s.server = thrift.NewTSimpleServer2(processor, s.transport) server = s.server + + s.started = true + return nil }() @@ -177,23 +178,14 @@ func (s *ExtensionManagerServer) Start() error { return server.Serve() } -// Run starts the extension manager and runs until an an interrupt -// signal is received. -// Run will call Shutdown before exiting. +// Run starts the extension manager and runs until osquery calls for a shutdown +// or the osquery instance goes away. func (s *ExtensionManagerServer) Run() error { errc := make(chan error) go func() { errc <- s.Start() }() - // Interrupt handler. - go func() { - sig := make(chan os.Signal) - signal.Notify(sig, os.Interrupt, os.Kill, syscall.SIGTERM) - <-sig - errc <- nil - }() - // Watch for the osquery process going away. If so, initiate shutdown. go func() { for { @@ -265,5 +257,19 @@ func (s *ExtensionManagerServer) Shutdown() error { server.Stop() }() } + return nil } + +// Useful for testing +func (s *ExtensionManagerServer) waitStarted() { + for { + s.mutex.Lock() + started := s.started + s.mutex.Unlock() + if started { + time.Sleep(10 * time.Millisecond) + break + } + } +} diff --git a/server_test.go b/server_test.go index fa83c1d..db8b0bb 100644 --- a/server_test.go +++ b/server_test.go @@ -106,12 +106,17 @@ func testShutdownDeadlock(t *testing.T) { }, } server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} + + wait := sync.WaitGroup{} + + wait.Add(1) go func() { err := server.Start() require.Nil(t, err) + wait.Done() }() - // Sleep long enough for server to start listening on socket - time.Sleep(500 * time.Millisecond) + // Wait for server to be set up + server.waitStarted() // Create a raw client to access the shutdown method that is not // usually exposed. @@ -127,7 +132,6 @@ func testShutdownDeadlock(t *testing.T) { // Simultaneously call shutdown through a request from the client and // directly on the server object. - wait := sync.WaitGroup{} wait.Add(1) go func() { defer wait.Done() @@ -148,6 +152,40 @@ func testShutdownDeadlock(t *testing.T) { close(completed) }() + // either indicate successful shutdown, or fatal the test because it + // hung + select { + case <-completed: + // Success. Do nothing. + case <-time.After(5 * time.Second): + t.Fatal("hung on shutdown") + } +} + +func TestShutdownBasic(t *testing.T) { + tempPath, err := ioutil.TempFile("", "") + require.Nil(t, err) + defer os.Remove(tempPath.Name()) + + retUUID := osquery.ExtensionRouteUUID(0) + mock := &MockExtensionManager{ + RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) { + return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil + }, + } + server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()} + + completed := make(chan struct{}) + go func() { + err := server.Start() + require.NoError(t, err) + close(completed) + }() + + server.waitStarted() + err = server.Shutdown() + require.NoError(t, err) + // Either indicate successful shutdown, or fatal the test because it // hung select {