Skip to content

Commit 5c951db

Browse files
committed
address PR comments
This removes the dedicated flagSet for streaming search and move parts of the logic to internal streaming.
1 parent fcf959b commit 5c951db

File tree

5 files changed

+73
-95
lines changed

5 files changed

+73
-95
lines changed

cmd/src/search.go

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
isatty "github.com/mattn/go-isatty"
1919
"github.com/sourcegraph/src-cli/internal/api"
20+
"github.com/sourcegraph/src-cli/internal/streaming"
2021
"jaytaylor.com/html2text"
2122
)
2223

@@ -50,14 +51,12 @@ Other tips:
5051

5152
flagSet := flag.NewFlagSet("search", flag.ExitOnError)
5253
var (
53-
jsonFlag = flagSet.Bool("json", false, "Whether or not to output results as JSON")
54+
jsonFlag = flagSet.Bool("json", false, "Whether or not to output results as JSON.")
5455
explainJSONFlag = flagSet.Bool("explain-json", false, "Explain the JSON output schema and exit.")
5556
apiFlags = api.NewFlags(flagSet)
56-
lessFlag = flagSet.Bool("less", true, "Pipe output to 'less -R' (only if stdout is terminal, and not json flag)")
57-
streamFlag = flagSet.Bool("stream", false, "Consume results as stream.")
58-
59-
// Streaming.
60-
_ = flagSet.Int("display", -1, "Limit the number of results shown. Only supported for streaming.")
57+
lessFlag = flagSet.Bool("less", true, "Pipe output to 'less -R' (only if stdout is terminal, and not json flag).")
58+
streamFlag = flagSet.Bool("stream", false, "Consume results as stream. Streaming search only supports a subset of flags and parameters: trace, insecure-skip-verify, display.")
59+
display = flagSet.Int("display", -1, "Limit the number of results that are displayed. Only supported together with stream flag. Statistics continue to report all results.")
6160
)
6261

6362
handler := func(args []string) error {
@@ -66,15 +65,12 @@ Other tips:
6665
}
6766

6867
if *streamFlag {
69-
// Remove -stream from args.
70-
argsWOStream := make([]string, 0, len(args)-1)
71-
for _, a := range args {
72-
if a == "-stream" {
73-
continue
74-
}
75-
argsWOStream = append(argsWOStream, a)
68+
opts := streaming.Opts{
69+
Display: *display,
70+
Trace: apiFlags.Trace(),
7671
}
77-
return streamHandler(argsWOStream)
72+
client := cfg.apiClient(apiFlags, flagSet.Output())
73+
return streamSearch(flagSet.Arg(0), opts, client, os.Stdout)
7874
}
7975

8076
if *explainJSONFlag {

cmd/src/search_stream.go

Lines changed: 3 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@ package main
22

33
import (
44
"bytes"
5-
"context"
6-
"flag"
75
"fmt"
86
"io"
9-
"net/url"
107
"os"
118
"os/exec"
129
"regexp"
@@ -23,51 +20,15 @@ func init() {
2320
labelRegexp, _ = regexp.Compile("(?:\\[)(.*?)(?:])")
2421
}
2522

26-
// streamHandler handles search requests which contain the flag "stream".
27-
// Requests are sent to search/stream instead of the GraphQL api.
28-
func streamHandler(args []string) error {
29-
flagSet := flag.NewFlagSet("streaming search", flag.ExitOnError)
30-
flags := newStreamingFlags(flagSet)
31-
if err := flagSet.Parse(args); err != nil {
32-
return err
33-
}
34-
35-
client := cfg.apiClient(flags.apiFlags, flagSet.Output())
36-
query := flagSet.Arg(0)
37-
return doStreamSearch(query, flags, client, os.Stdout)
38-
}
39-
40-
func doStreamSearch(query string, flags *streamingFlags, client api.Client, w io.Writer) error {
23+
func streamSearch(query string, opts streaming.Opts, client api.Client, w io.Writer) error {
4124
t, err := parseTemplate(streamingTemplate)
4225
if err != nil {
4326
panic(err)
4427
}
45-
46-
// Create request.
47-
req, err := client.NewHTTPRequest(context.Background(), "GET", "search/stream?q="+url.QueryEscape(query), nil)
48-
if err != nil {
49-
return err
50-
}
51-
req.Header.Set("Accept", "text/event-stream")
52-
if flags.display >= 0 {
53-
q := req.URL.Query()
54-
q.Add("display", strconv.Itoa(flags.display))
55-
req.URL.RawQuery = q.Encode()
56-
}
57-
58-
// Send request.
59-
resp, err := client.Do(req)
60-
if err != nil {
61-
return fmt.Errorf("error sending request: %w", err)
62-
}
63-
defer resp.Body.Close()
64-
6528
logError := func(msg string) {
6629
_, _ = fmt.Fprintf(os.Stderr, msg)
6730
}
68-
69-
// Process response.
70-
err = streaming.Decoder{
31+
decoder := streaming.Decoder{
7132
OnProgress: func(progress *streaming.Progress) {
7233
// We only show the final progress.
7334
if !progress.Done {
@@ -157,19 +118,9 @@ func doStreamSearch(query string, flags *streamingFlags, client api.Client, w io
157118
}
158119
}
159120
},
160-
}.ReadAll(resp.Body)
161-
if err != nil {
162-
return fmt.Errorf("error during decoding: %w", err)
163121
}
164122

165-
// Write trace to output.
166-
if flags.Trace() {
167-
_, err = fmt.Fprintf(os.Stderr, fmt.Sprintf("x-trace: %s\n", resp.Header.Get("x-trace")))
168-
if err != nil {
169-
return err
170-
}
171-
}
172-
return nil
123+
return streaming.Search(query, opts, client, decoder)
173124
}
174125

175126
const streamingTemplate = `
@@ -412,20 +363,3 @@ func streamConvertMatchToHighlights(m streaming.EventLineMatch, isPreview bool)
412363
}
413364
return highlights
414365
}
415-
416-
type streamingFlags struct {
417-
apiFlags *api.Flags
418-
display int
419-
}
420-
421-
func newStreamingFlags(flagSet *flag.FlagSet) *streamingFlags {
422-
flags := &streamingFlags{
423-
apiFlags: api.StreamingFlags(flagSet),
424-
}
425-
flagSet.IntVar(&flags.display, "display", -1, "Limit the number of results that are displayed. Note that the statistics continue to report all results.")
426-
return flags
427-
}
428-
429-
func (f *streamingFlags) Trace() bool {
430-
return f.apiFlags.Trace()
431-
}

cmd/src/search_stream_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/google/go-cmp/cmp"
1313

14+
"github.com/sourcegraph/src-cli/internal/api"
1415
"github.com/sourcegraph/src-cli/internal/streaming"
1516
)
1617

@@ -117,11 +118,10 @@ func TestSearchStream(t *testing.T) {
117118
t.Fatal(err)
118119
}
119120

120-
flagSet := flag.NewFlagSet("streaming search test", flag.ExitOnError)
121-
flags := newStreamingFlags(flagSet)
122-
client := cfg.apiClient(flags.apiFlags, flagSet.Output())
123-
124-
err = doStreamSearch("", flags, client, w)
121+
flagSet := flag.NewFlagSet("test", flag.ExitOnError)
122+
flags := api.NewFlags(flagSet)
123+
client := cfg.apiClient(flags, flagSet.Output())
124+
err = streamSearch("", streaming.Opts{}, client, w)
125125
if err != nil {
126126
t.Fatal(err)
127127
}

internal/api/flags.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,3 @@ func defaultFlags() *Flags {
3838
insecureSkipVerify: &d,
3939
}
4040
}
41-
42-
func StreamingFlags(flagSet *flag.FlagSet) *Flags {
43-
return &Flags{
44-
trace: flagSet.Bool("trace", false, "Log the trace ID for requests. See https://docs.sourcegraph.com/admin/observability/tracing"),
45-
insecureSkipVerify: flagSet.Bool("insecure-skip-verify", false, "Skip validation of TLS certificates against trusted chains"),
46-
}
47-
}

internal/streaming/search.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package streaming
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/url"
7+
"os"
8+
"strconv"
9+
10+
"github.com/sourcegraph/src-cli/internal/api"
11+
)
12+
13+
// Opts contains the search options supported by Search.
14+
type Opts struct {
15+
Display int
16+
Trace bool
17+
}
18+
19+
// Search calls the streaming search endpoint and uses decoder to decode the
20+
// response body.
21+
func Search(query string, opts Opts, client api.Client, decoder Decoder) error {
22+
// Create request.
23+
req, err := client.NewHTTPRequest(context.Background(), "GET", "search/stream?q="+url.QueryEscape(query), nil)
24+
if err != nil {
25+
return err
26+
}
27+
req.Header.Set("Accept", "text/event-stream")
28+
if opts.Display >= 0 {
29+
q := req.URL.Query()
30+
q.Add("display", strconv.Itoa(opts.Display))
31+
req.URL.RawQuery = q.Encode()
32+
}
33+
34+
// Send request.
35+
resp, err := client.Do(req)
36+
if err != nil {
37+
return fmt.Errorf("error sending request: %w", err)
38+
}
39+
defer resp.Body.Close()
40+
41+
// Process response.
42+
err = decoder.ReadAll(resp.Body)
43+
if err != nil {
44+
return fmt.Errorf("error during decoding: %w", err)
45+
}
46+
47+
// Output trace.
48+
if opts.Trace {
49+
_, err = fmt.Fprintf(os.Stderr, fmt.Sprintf("x-trace: %s\n", resp.Header.Get("x-trace")))
50+
if err != nil {
51+
return err
52+
}
53+
}
54+
return nil
55+
}

0 commit comments

Comments
 (0)