diff --git a/main.go b/main.go index 0313015..673e2d5 100644 --- a/main.go +++ b/main.go @@ -3,8 +3,10 @@ package main import ( "bufio" "context" + "errors" "flag" "fmt" + "io" "log" "net" "net/url" @@ -135,8 +137,9 @@ func main() { deathNote := sync.Map{} - connectionAccepted := make(chan net.Addr) - connectionLost := make(chan net.Addr) + // Buffered so we don't block the main process. + connectionAccepted := make(chan net.Addr, 10) + connectionLost := make(chan net.Addr, 10) go processRequests(&deathNote, connectionAccepted, connectionLost) @@ -206,8 +209,10 @@ func processRequests(deathNote *sync.Map, connectionAccepted chan<- net.Addr, co } if err != nil { - log.Println(err) - break + if !errors.Is(err, io.EOF) { + log.Println(err) + } + return } } }(conn) @@ -216,45 +221,28 @@ func processRequests(deathNote *sync.Map, connectionAccepted chan<- net.Addr, co func waitForPruneCondition(ctx context.Context, connectionAccepted <-chan net.Addr, connectionLost <-chan net.Addr) { connectionCount := 0 - never := make(chan time.Time, 1) - defer close(never) - - handleConnectionAccepted := func(addr net.Addr) { - log.Printf("New client connected: %s", addr) - connectionCount++ - } - - select { - case <-time.After(connectionTimeout): - panic("Timed out waiting for the first connection") - case addr := <-connectionAccepted: - handleConnectionAccepted(addr) - case <-ctx.Done(): - log.Println("Signal received") - return - } - + timer := time.NewTimer(connectionTimeout) for { - var noConnectionTimeout <-chan time.Time - if connectionCount == 0 { - noConnectionTimeout = time.After(reconnectionTimeout) - } else { - noConnectionTimeout = never - } - select { case addr := <-connectionAccepted: - handleConnectionAccepted(addr) - break + log.Printf("New client connected: %s", addr) + connectionCount++ + if connectionCount == 1 { + if !timer.Stop() { + <-timer.C + } + } case addr := <-connectionLost: log.Printf("Client disconnected: %s", addr.String()) connectionCount-- - break + if connectionCount == 0 { + timer.Reset(reconnectionTimeout) + } case <-ctx.Done(): log.Println("Signal received") return - case <-noConnectionTimeout: - log.Println("Timed out waiting for re-connection") + case <-timer.C: + log.Println("Timeout waiting for connection") return } } diff --git a/main_test.go b/main_test.go index 4fa08e1..f5a7a9e 100644 --- a/main_test.go +++ b/main_test.go @@ -6,10 +6,10 @@ import ( "context" "fmt" "io" + "log" "net" "os" "path/filepath" - "strings" "sync" "testing" "time" @@ -64,26 +64,27 @@ func TestInitialTimeout(t *testing.T) { // reset connectionTimeout connectionTimeout = testConnectionTimeout + origWriter := log.Default().Writer() + defer func() { + log.SetOutput(origWriter) + }() + var buf bytes.Buffer + log.SetOutput(&buf) + acc := make(chan net.Addr) lost := make(chan net.Addr) - done := make(chan string) go func() { - defer func() { - err := recover().(string) - done <- err - }() waitForPruneCondition(context.Background(), acc, lost) + done <- buf.String() }() select { case p := <-done: - if !strings.Contains(p, "first connection") { - t.Fail() - } + require.Contains(t, p, "Timeout waiting for connection") case <-time.After(7 * time.Second): - t.Fail() + t.Fatal("Timeout waiting prune condition") } }