diff --git a/cmd/run.go b/cmd/run.go index a71b02f..914d736 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -6,7 +6,6 @@ import ( "log" "os" "strings" - "sync" "time" "github.com/aws/aws-sdk-go/aws" @@ -54,7 +53,7 @@ func init() { func runRun(cmd *cobra.Command, args []string) { l := len(args) - if l < 1 || l > 1 { // TODO: run interactive mode if no argument is given + if l != 1 { // TODO: run interactive mode if no argument is given cmd.Help() return } @@ -77,48 +76,48 @@ func runRun(cmd *cobra.Command, args []string) { stmts := strings.Split(args[0], ";") // Create channels - resultCh := make(chan *exec.Result) - errCh := make(chan error) - doneCh := make(chan struct{}) - var wg sync.WaitGroup + l = len(stmts) + resultChs := make([]chan *exec.Result, 0, l) + errChs := make([]chan error, 0, l) // Run each statement concurrently using goroutine for _, stmt := range stmts { - query := stmt // capture locally - if strings.TrimSpace(query) == "" { + if strings.TrimSpace(stmt) == "" { continue // Skip empty statements } - wg.Add(1) - go runQuery(client, query, resultCh, errCh, &wg) - } - // Monitoring goroutine to notify that all the query executions have finished - go func() { - wg.Wait() - doneCh <- struct{}{} - }() + resultCh := make(chan *exec.Result) + errCh := make(chan error) + go runQuery(client, stmt, resultCh, errCh) + + resultChs = append(resultChs, resultCh) + errChs = append(errChs, errCh) + } fmt.Print("Running query") - // TODO: arrange results in the order of the original statements tick := time.Tick(tickInterval) - for { - select { - case r := <-resultCh: - fmt.Print("\n") - print.NewTable(os.Stdout).Print(r) - case e := <-errCh: - fmt.Print("\n") - fmt.Fprintln(os.Stderr, e) - case <-tick: - fmt.Print(".") - case <-doneCh: - return + l = len(resultChs) + for i := 0; i < l; i++ { + loop: + for { + select { + case r := <-resultChs[i]: + fmt.Print("\n") + print.NewTable(os.Stdout).Print(r) + break loop + case e := <-errChs[i]: + fmt.Print("\n") + fmt.Fprintln(os.Stderr, e) + break loop + case <-tick: + fmt.Print(".") + } } } } -func runQuery(client athenaiface.AthenaAPI, query string, resultCh chan *exec.Result, errCh chan error, wg *sync.WaitGroup) { +func runQuery(client athenaiface.AthenaAPI, query string, resultCh chan *exec.Result, errCh chan error) { // Run a query, and send results or an error r, err := exec.NewQuery(client, query, queryConfig).Run() if err != nil { @@ -126,5 +125,4 @@ func runQuery(client athenaiface.AthenaAPI, query string, resultCh chan *exec.Re } else { resultCh <- r } - wg.Done() }