From 51e62b670c8b2543ceddfd2b2dd15bec5955f2fd Mon Sep 17 00:00:00 2001 From: Kroese Date: Sun, 10 Dec 2023 06:11:50 +0100 Subject: [PATCH] feat: Add timeout parameter --- src/main.go | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/src/main.go b/src/main.go index e49f7c7..b2f0c20 100644 --- a/src/main.go +++ b/src/main.go @@ -81,7 +81,7 @@ var GuestSN = flag.String("guestsn", "0000000000000", "Guest serial number") var GuestCPU_Arch = flag.String("cpu_arch", "QEMU, Virtual CPU, X86_64", "CPU arch") var ApiPort = flag.String("api", ":2210", "API port") -var ApiTimeout = flag.Int("timeout", 45, "API timeout") +var ApiTimeout = flag.Int("timeout", 10, "Default timeout") var ListenAddr = flag.String("addr", "0.0.0.0:12345", "Listen address") func main() { @@ -308,13 +308,31 @@ func read(w http.ResponseWriter, r *http.Request) { defer Writer.Unlock() query := r.URL.Query() - commandID, err := strconv.Atoi(query.Get("command")) + cmd := query.Get("command") + timeout := query.Get("timeout") + wait := time.Duration(*ApiTimeout) + + if len(strings.TrimSpace(cmd)) == 0 { + fail(w, "No command specified") + return + } + + commandID, err := strconv.Atoi(cmd) if err != nil || commandID < 1 { - fail(w, fmt.Sprintf("Failed to parse command %s", query.Get("command"))) + fail(w, fmt.Sprintf("Failed to parse command: %s", cmd)) return } + if len(strings.TrimSpace(timeout)) > 0 { + duration, err := strconv.Atoi(timeout) + if err != nil || duration < 1 { + fail(w, fmt.Sprintf("Failed to parse timeout: %s", timeout)) + return + } + wait = time.Duration(duration) + } + if Connection == nil || Chan == nil { fail(w, "No connection to guest") return @@ -339,7 +357,7 @@ func read(w http.ResponseWriter, r *http.Request) { select { case res := <-Chan: resp = res - case <-time.After(time.Duration(*ApiTimeout) * time.Second): + case <-time.After(wait * time.Second): atomic.StoreInt32(&WaitingFor, 0) fail(w, fmt.Sprintf("Timeout while reading command %d from guest", commandID)) return @@ -374,10 +392,17 @@ func write(w http.ResponseWriter, r *http.Request) { } query := r.URL.Query() - commandID, err := strconv.Atoi(query.Get("command")) + cmd := query.Get("command") + + if len(strings.TrimSpace(cmd)) == 0 { + fail(w, "No command specified") + return + } + + commandID, err := strconv.Atoi(cmd) if err != nil || commandID < 1 { - fail(w, fmt.Sprintf("Failed to parse command %s", query.Get("command"))) + fail(w, fmt.Sprintf("Failed to parse command: %s", cmd)) return }