From 0a87478b8e1c02aff5a3d1be1b93ce56b3f96535 Mon Sep 17 00:00:00 2001 From: Taylor Bartlett Date: Fri, 30 Oct 2020 18:46:45 -0600 Subject: [PATCH 1/3] Inital pass at run time variables --- builder.go | 24 +++- cmd/rest/main.go | 7 ++ go.mod | 4 +- go.sum | 4 +- lex.go | 261 +++++++++++++++++++++++++++----------------- rest.go | 96 +++++++++++----- rest_test.go | 36 ++++++ synthesizer.go | 21 ++-- synthesizer_test.go | 23 ++-- test/client.rest | 12 +- test/runtime.rest | 15 +++ 11 files changed, 350 insertions(+), 153 deletions(-) create mode 100644 test/runtime.rest diff --git a/builder.go b/builder.go index ef413ae..6eb150c 100644 --- a/builder.go +++ b/builder.go @@ -10,16 +10,38 @@ import ( "os" "path/filepath" "strings" + "time" + + "github.com/taybart/log" ) +type request struct { + label string + skip bool + r *http.Request + delay time.Duration + expectation expectation + outputs map[string]string +} + // type builder struct{} // buildRequest : generate http.Request from parsed input -func buildRequest(input metaRequest) (req request, err error) { +func buildRequest(input metaRequest, variables map[string]restVar) (req request, err error) { if err = isValidMetaRequest(input); err != nil { return } + if input.reinterpret { + log.Debug("Re-interpreting request", variables) + l := newLexer(false) + l.variables = variables + input, err = l.parseBlock(input.block) + if err != nil { + return + } + } + var r *http.Request url := fmt.Sprintf("%s%s", input.url, input.path) if !input.skip { // don't validate if skipped diff --git a/cmd/rest/main.go b/cmd/rest/main.go index 81f95c0..ef62ec5 100644 --- a/cmd/rest/main.go +++ b/cmd/rest/main.go @@ -55,15 +55,18 @@ func help() { func main() { flag.Parse() + log.SetPlain() log.SetLevel(log.WARN) if verbose { log.SetLevel(log.DEBUG) } + if servelog || servedir { serve(servedir, local, port) return } + r := rest.New() if nocolor { r.NoColor() @@ -107,6 +110,7 @@ func main() { func readFiles(r *rest.Rest) { for _, f := range fns { + log.Debug("Reading file %s...", f) if fileExists(f) { valid, err := r.IsRestFile(f) if !valid { @@ -116,12 +120,15 @@ func readFiles(r *rest.Rest) { err = r.Read(f) if err != nil { log.Error("Read error", err) + continue } + log.Debug("done\n") } } } func exec(r *rest.Rest) { + log.Debug("\nExecuting all requests\n") if index >= 0 { res, err := r.ExecIndex(index) if err != nil { diff --git a/go.mod b/go.mod index 2690922..25e1338 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/taybart/rest -go 1.13 +go 1.14 require ( github.com/matryer/is v1.2.0 - github.com/taybart/log v1.1.1 + github.com/taybart/log v1.2.2 ) diff --git a/go.sum b/go.sum index b761559..9ce4f66 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A= github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA= -github.com/taybart/log v1.1.1 h1:cuYtzjywA8D1rxmmyQ7ry5fwduAcktmxr7/WZxAf1Xw= -github.com/taybart/log v1.1.1/go.mod h1:bLpgJt6GrTPJNabumQPcFTaYJnGwj7mSMs/OQwTDAdE= +github.com/taybart/log v1.2.2 h1:jlg3dibUsaanJLzofXzRmBMD+JHHPqbXiYpXdgXTYa8= +github.com/taybart/log v1.2.2/go.mod h1:e9MmKdjMsNxSFbn46ag778NSmhFrx9yLC60CDRRcSOo= diff --git a/lex.go b/lex.go index 405623b..1921786 100644 --- a/lex.go +++ b/lex.go @@ -3,7 +3,6 @@ package rest import ( "bufio" "fmt" - "net/http" "regexp" "strconv" "strings" @@ -19,81 +18,75 @@ const ( stateBody ) +var ( + rxLabel = regexp.MustCompile(`^label (.*)`) + rxSkip = regexp.MustCompile(`^skip\s*$`) + rxDelay = regexp.MustCompile(`^delay (\d+(ns|us|µs|ms|s|m|h))$`) + rxVarDefinition = regexp.MustCompile(`^set ([[:word:]\-]+) (.+)`) + rxURL = regexp.MustCompile(`^(https?)://[^\s/$.?#]*[^\s]*$`) + rxHeader = regexp.MustCompile(`[a-zA-Z-]+: .+`) + rxMethod = regexp.MustCompile(`^(OPTIONS|GET|POST|PUT|DELETE)`) + rxPath = regexp.MustCompile(`\/.*`) + rxFile = regexp.MustCompile(`^file://([/a-zA-Z0-9\-_\.]+)[\s+]?([a-zA-Z0-9]+)?$`) + rxVar = regexp.MustCompile(`\$\{([[:word:]\-]+)\}`) + rxExpect = regexp.MustCompile(`^expect (\d+) ?(.*)`) + rxComment = regexp.MustCompile(`^[[:space:]]*[#|\/\/]`) + rxRuntimeVar = regexp.MustCompile(`^take ([[:word:]]+) as ([[:word:]\-]+)`) +) + +type restVar struct { + name string + value string + runtime bool +} + type expectation struct { code int body string } - type metaRequest struct { - label string - skip bool - url string - headers map[string]string - method string - path string - body string - filepath string - filelabel string - delay time.Duration - expectation expectation + label string + skip bool + url string + headers map[string]string + method string + path string + body string + filepath string + filelabel string + delay time.Duration + expectation expectation + reinterpret bool + reinterpretVars []restVar + block []string } - -type request struct { - label string - skip bool - r *http.Request - delay time.Duration - expectation expectation +type requestBatch struct { + requests []metaRequest + rtVars map[string]restVar } type lexer struct { - rxLabel *regexp.Regexp - rxSkip *regexp.Regexp - rxDelay *regexp.Regexp - rxVarDefinition *regexp.Regexp - rxURL *regexp.Regexp - rxHeader *regexp.Regexp - rxPath *regexp.Regexp - rxMethod *regexp.Regexp - rxFile *regexp.Regexp - rxVar *regexp.Regexp - rxExpect *regexp.Regexp - rxComment *regexp.Regexp - - variables map[string]string + variables map[string]restVar concurrent bool - bch chan request + bch chan metaRequest } func newLexer(concurrent bool) lexer { return lexer{ - rxLabel: regexp.MustCompile(`^label (.*)`), - rxSkip: regexp.MustCompile(`^skip\s*$`), - rxDelay: regexp.MustCompile(`^delay (\d+(ns|us|µs|ms|s|m|h))$`), - rxVarDefinition: regexp.MustCompile(`^set ([[:word:]\-]+) (.+)`), - rxURL: regexp.MustCompile(`^(https?)://[^\s/$.?#]*[^\s]*$`), - rxHeader: regexp.MustCompile(`[a-zA-Z-]+: .+`), - rxMethod: regexp.MustCompile(`^(OPTIONS|GET|POST|PUT|DELETE)`), - rxPath: regexp.MustCompile(`\/.*`), - rxFile: regexp.MustCompile(`^file://([/a-zA-Z0-9\-_\.]+)[\s+]?([a-zA-Z0-9]+)?$`), - rxVar: regexp.MustCompile(`\$\{([[:word:]\-]+)\}`), - rxExpect: regexp.MustCompile(`^expect (\d+) ?(.*)`), - rxComment: regexp.MustCompile(`^[[:space:]]*[#|\/\/]`), - - variables: make(map[string]string), + variables: make(map[string]restVar), concurrent: concurrent, - bch: make(chan request), + bch: make(chan metaRequest), } } // parse : Parse a rest file and build golang http requests from it -func (l *lexer) parse(scanner *bufio.Scanner) ([]request, error) { - log.Debug("Lex starting parse") +func (l *lexer) parse(scanner *bufio.Scanner) (requests requestBatch, err error) { + log.Debug("\nLex starting parse...") blocks := [][]string{} block := []string{} for scanner.Scan() { line := scanner.Text() - if line == "---" { + if line == "---" { // next block blocks = append(blocks, block) block = []string{} continue @@ -104,65 +97,119 @@ func (l *lexer) parse(scanner *bufio.Scanner) ([]request, error) { blocks = append(blocks, block) log.Debugf("Got %d blocks\n", len(blocks)) + p, err := l.firstPass(blocks) + if err != nil { + return + } + rtVars := make(map[string]restVar) + for k, v := range l.variables { + if v.runtime { + log.Debugf("var: %s is runtime\n", k) + rtVars[k] = v + } + } + + var rs []metaRequest if l.concurrent { - return l.parseBlocksConcurrently(blocks) + rs, err = l.parseConcurrent(p) + } else { + rs, err = l.parseSerial(p) + } + if err != nil { + return + } + return requestBatch{ + requests: rs, + rtVars: rtVars, + }, nil +} + +func (l *lexer) firstPass(blocks [][]string) (meta []metaRequest, err error) { + for i, b := range blocks { + for _, ln := range b { + switch { + case rxSkip.MatchString(ln): + continue + case rxRuntimeVar.MatchString(ln): + if l.concurrent { + err = fmt.Errorf("found runtime variable but rest is set to run concurrently") + return + } + v := rxRuntimeVar.FindStringSubmatch(ln) + log.Debugf("Found runtime variable %s with return value of %s\n", v[2], v[1]) + l.variables[v[2]] = restVar{ + name: v[2], + value: v[1], + runtime: true, + } + } + } + log.Debug("First pass on block", i) + meta = append(meta, metaRequest{ + block: b, + }) } - return l.parseBlocks(blocks) + return } // parseBlocks : Parse blocks in the order in which they were given -func (l *lexer) parseBlocks(blocks [][]string) (reqs []request, err error) { +func (l *lexer) parseSerial(input []metaRequest) (reqs []metaRequest, err error) { log.Debug("Starting to parse blocks in order") - for i, block := range blocks { - r, e := l.parseBlock(block) + for i, r := range input { + lexed, e := l.parseBlock(r.block) if e != nil { err = fmt.Errorf("block %d: %w", i, e) // log.Error(e) continue // TODO maybe should super fail } - reqs = append(reqs, r) + reqs = append(reqs, lexed) } log.Debugf("Parsed %d blocks\n", len(reqs)) - l.variables = make(map[string]string) // purge vars + l.purgeVars() return } // parseBlocksConcurrently : Parse all blocks but don't care about order -func (l *lexer) parseBlocksConcurrently(blocks [][]string) (reqs []request, err error) { +func (l *lexer) parseConcurrent(input []metaRequest) (reqs []metaRequest, err error) { log.Debug("Starting to parse blocks concurrently") - for _, block := range blocks { - go l.parseBlock(block) + for _, r := range input { + go l.parseBlock(r.block) } - for i := 0; i < len(blocks); i++ { + for i := 0; i < len(input); i++ { r := <-l.bch reqs = append(reqs, r) } log.Debug("Done") - l.variables = make(map[string]string) // purge vars + l.purgeVars() return } // parseBlock : Get all parts of request from request block -func (l *lexer) parseBlock(block []string) (request, error) { +func (l *lexer) parseBlock(block []string) (metaRequest, error) { req := metaRequest{ headers: make(map[string]string), } state := stateUrl for i, ln := range block { - if l.rxComment.MatchString(ln) { + if rxComment.MatchString(ln) { log.Debug("Get comment", ln) continue } - line, err := l.checkForVariables(ln) + line, runtime, err := l.checkForUndeclaredVariables(ln) if err != nil { log.Fatal(err) } + if runtime { + req.block = block + req.reinterpret = true + continue + } switch { - case l.rxSkip.MatchString(line): + case rxSkip.MatchString(line): req.skip = true - case l.rxExpect.MatchString(line): - m := l.rxExpect.FindStringSubmatch(line) + case rxExpect.MatchString(line): + m := rxExpect.FindStringSubmatch(line) if len(m) == 1 { log.Errorf("Malformed expectation in block %d [%s]\n", i, line) continue @@ -175,52 +222,55 @@ func (l *lexer) parseBlock(block []string) (request, error) { if len(m) == 3 { req.expectation.body = m[2] } - case l.rxDelay.MatchString(line): - m := l.rxDelay.FindStringSubmatch(line) + case rxDelay.MatchString(line): + m := rxDelay.FindStringSubmatch(line) req.delay, err = time.ParseDuration(m[1]) if err != nil { log.Errorf("Cannot parse delay in block %d [%s]\n", i, line) continue } - case l.rxVarDefinition.MatchString(line): - v := l.rxVarDefinition.FindStringSubmatch(line) + case rxVarDefinition.MatchString(line): + v := rxVarDefinition.FindStringSubmatch(line) log.Debugf("Setting %s to %s\n", string(v[1]), string(v[2])) - l.variables[v[1]] = v[2] - case l.rxURL.MatchString(line): - u := l.rxURL.FindString(line) + l.variables[v[1]] = restVar{ + name: v[1], + value: v[2], + } + case rxURL.MatchString(line): + u := rxURL.FindString(line) if isUrl(u) { req.url = u log.Debug("Got URL", u) } state = stateHeaders - case l.rxMethod.MatchString(line): - m := l.rxMethod.FindString(line) + case rxMethod.MatchString(line): + m := rxMethod.FindString(line) req.method = m - p := l.rxPath.FindString(line) + p := rxPath.FindString(line) req.path = p log.Debug("Got method", m) log.Debug("Got path", p) state = stateBody - case l.rxHeader.MatchString(line) && state == stateHeaders: + case rxHeader.MatchString(line) && state == stateHeaders: sp := strings.Split(line, ":") key := strings.TrimSpace(sp[0]) value := strings.TrimSpace(sp[1]) req.headers[key] = value log.Debugf("Set header %s to %s\n", key, value) - case l.rxFile.MatchString(line): - // fn := l.rxFile.FindString(line) - matches := l.rxFile.FindStringSubmatch(line) + case rxFile.MatchString(line): + // fn := rxFile.FindString(line) + matches := rxFile.FindStringSubmatch(line) if isValidFile(matches[1]) { req.filepath = matches[1] req.filelabel = matches[2] log.Debug("Got File", req.filepath, req.filelabel) } state = stateHeaders - case l.rxLabel.MatchString(line): - m := l.rxLabel.FindStringSubmatch(line) + case rxLabel.MatchString(line): + m := rxLabel.FindStringSubmatch(line) req.label = m[1] case state == stateBody: @@ -228,30 +278,37 @@ func (l *lexer) parseBlock(block []string) (request, error) { } } log.Debug("Building request") - r, err := buildRequest(req) - if err != nil { - return request{}, err - } if l.concurrent { - l.bch <- r + l.bch <- req } - return r, nil + return req, nil } -func (l lexer) checkForVariables(line string) (string, error) { +func (l lexer) checkForUndeclaredVariables(line string) (string, bool, error) { tmp := line - if l.rxVar.MatchString(line) { - matches := l.rxVar.FindAllStringSubmatch(line, -1) + reinterpret := false + if rxVar.MatchString(line) { + matches := rxVar.FindAllStringSubmatch(line, -1) for _, match := range matches { - if value, ok := l.variables[match[1]]; ok { - tmp = strings.ReplaceAll(tmp, match[0], value) - } else { - return "", fmt.Errorf("Saw variable %s%s%s and did not have a value for it", - log.Blue, match[1], log.Rtd) + if l.variables[match[1]].runtime { + tmp = l.variables[match[1]].value + reinterpret = true + log.Debug(line, "-> NEED RUNTIME VALUE") + continue + } + if v, ok := l.variables[match[1]]; ok { + tmp = strings.ReplaceAll(tmp, match[0], v.value) + return tmp, false, nil } + return "", false, fmt.Errorf("Saw variable %s%s%s and did not have a value for it", + log.Blue, match[1], log.Rtd) log.Debug(line, "->", tmp) } } - return tmp, nil + return tmp, reinterpret, nil +} + +func (l *lexer) purgeVars() { + l.variables = make(map[string]restVar) } diff --git a/rest.go b/rest.go index e452352..b550a84 100644 --- a/rest.go +++ b/rest.go @@ -3,6 +3,7 @@ package rest import ( "bufio" "bytes" + "encoding/json" "fmt" "io" "io/ioutil" @@ -18,6 +19,7 @@ import ( type Rest struct { color bool client *http.Client + lexed requestBatch requests []request } @@ -26,6 +28,9 @@ func New() *Rest { return &Rest{ color: true, client: http.DefaultClient, + lexed: requestBatch{ + rtVars: make(map[string]restVar), + }, } } @@ -42,15 +47,7 @@ func (r *Rest) SetClient(c *http.Client) { // ReadIO : read ordered requests from io reader func (r *Rest) ReadIO(buf io.Reader) error { scanner := bufio.NewScanner(buf) - lex := newLexer( - false, // concurrent - ) - reqs, err := lex.parse(scanner) - if err != nil { - return err - } - r.requests = append(r.requests, reqs...) - return nil + return r.read(scanner, false) } // Read : read ordered requests from file @@ -62,15 +59,7 @@ func (r *Rest) Read(fn string) error { defer file.Close() scanner := bufio.NewScanner(file) - lex := newLexer( - false, // concurrent - ) - reqs, err := lex.parse(scanner) - if err != nil { - return err - } - r.requests = append(r.requests, reqs...) - return nil + return r.read(scanner, false) } // ReadConcurrent : read unordered requests from file @@ -82,24 +71,34 @@ func (r *Rest) ReadConcurrent(fn string) error { defer file.Close() scanner := bufio.NewScanner(file) - lex := newLexer( - true, // concurrent - ) + return r.read(scanner, true) +} + +func (r *Rest) read(scanner *bufio.Scanner, concurrent bool) error { + lex := newLexer(concurrent) reqs, err := lex.parse(scanner) if err != nil { return err } - r.requests = append(r.requests, reqs...) + r.lexed.requests = append(r.lexed.requests, reqs.requests...) + for k, v := range reqs.rtVars { + r.lexed.rtVars[k] = v + } return nil } // Exec : do all loaded requests -func (r *Rest) Exec() (successful []string, err error) { - // TODO create error report - for i, req := range r.requests { +func (r Rest) Exec() (successful []string, err error) { + for i, l := range r.lexed.requests { + var req request + req, err = buildRequest(l, r.lexed.rtVars) + if err != nil { + return + } if req.skip { continue } + time.Sleep(req.delay) log.Debugf("Sending request %d to %s\n", i, req.r.URL.String()) var resp *http.Response @@ -117,6 +116,11 @@ func (r *Rest) Exec() (successful []string, err error) { return } + err = r.takeVariables(resp, &r.lexed.rtVars) + if err != nil { + return + } + var dump []byte dump, err = httputil.DumpResponse(resp, true) if err != nil { @@ -152,12 +156,16 @@ func (r *Rest) Exec() (successful []string, err error) { } // ExecIndex : do specific block in requests -func (r *Rest) ExecIndex(i int) (result string, err error) { +func (r Rest) ExecIndex(i int) (result string, err error) { if i > len(r.requests)-1 { err = fmt.Errorf("Block %d does not exist", i) return } - req := r.requests[i] + + req, err := buildRequest(r.lexed.requests[i], map[string]restVar{}) + if err != nil { + return + } time.Sleep(req.delay) log.Debugf("Sending request %d to %s\n", i, req.r.URL.String()) resp, err := r.client.Do(req.r) @@ -195,7 +203,7 @@ func (r *Rest) ExecIndex(i int) (result string, err error) { } // CheckExpectation : ensure request did what is was supposed to -func (r *Rest) CheckExpectation(req request, res *http.Response) error { +func (r Rest) CheckExpectation(req request, res *http.Response) error { exp := req.expectation if exp.code == 0 { return nil @@ -222,7 +230,8 @@ func (r *Rest) CheckExpectation(req request, res *http.Response) error { } // IsRestFile : checks if file can be parsed -func (r *Rest) IsRestFile(fn string) (bool, error) { +func (r Rest) IsRestFile(fn string) (bool, error) { + log.Debugf("Checking if %s is a valid rest file\n", fn) file, err := os.Open(fn) if err != nil { return false, err @@ -237,5 +246,34 @@ func (r *Rest) IsRestFile(fn string) (bool, error) { if err != nil { return false, fmt.Errorf("Invalid format or malformed file: %w", err) } + log.Debugf("Yay! %s is valid!\n", fn) return true, nil } + +func (rest Rest) takeVariables(res *http.Response, rtVars *map[string]restVar) (err error) { + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return + } + if len(body) == 0 { + return + } + var j map[string]string + err = json.Unmarshal(body, &j) + if err != nil { + return + } + for k, v := range *rtVars { + for jk, jv := range j { + if v.value == jk { + (*rtVars)[k] = restVar{ + name: k, + value: jv, + } + } + } + } + + return +} diff --git a/rest_test.go b/rest_test.go index 0a572f8..f924ebb 100644 --- a/rest_test.go +++ b/rest_test.go @@ -206,3 +206,39 @@ func TestExpect(t *testing.T) { _, err = r.Exec() is.NoErr(err) } + +func TestRuntimeVariables(t *testing.T) { + is := is.New(t) + r := New() + err := r.Read("./test/runtime.rest") + is.NoErr(err) + client := NewTestClient(func(r *http.Request) *http.Response { + switch r.URL.Path { + case "/login": + // Test request parameters + return &http.Response{ + StatusCode: 200, + // Send response to be tested + Body: ioutil.NopCloser(bytes.NewBufferString(`{"auth_token": "test"}`)), + // Must be set to non-nil value or it panics + Header: make(http.Header), + } + case "/account": + if r.Header.Get("Authorization") != "Bearer test" { + t.Fatal("auth_token was not present during second call") + } + return &http.Response{ + StatusCode: 200, + // Must be set to non-nil value or it panics + Header: make(http.Header), + } + default: + t.Fatal("Unknown url called") + return nil + } + }) + + r.SetClient(client) + _, err = r.Exec() + is.NoErr(err) +} diff --git a/synthesizer.go b/synthesizer.go index 2e60805..0e3edd9 100644 --- a/synthesizer.go +++ b/synthesizer.go @@ -47,11 +47,18 @@ func (r Rest) SynthesizeClient(lang string) (string, error) { // SynthisizeRequests : output request code func (r Rest) SynthesizeRequests(lang string) ([]string, error) { if t := templates.Get(lang); t != nil { - requests := make([]string, len(r.requests)) - for i, req := range r.requests { - body, err := ioutil.ReadAll(req.r.Body) - if err != nil { - log.Error(err) + requests := []string{} + for _, req := range r.requests { + if req.skip { + continue + } + var body []byte + if req.r.Body != nil { + var err error + body, err = ioutil.ReadAll(req.r.Body) + if err != nil { + log.Error(err) + } } templReq := struct { URL string @@ -66,11 +73,11 @@ func (r Rest) SynthesizeRequests(lang string) ([]string, error) { } var buf bytes.Buffer - err = t.Request.Execute(&buf, templReq) + err := t.Request.Execute(&buf, templReq) if err != nil { log.Error(err) } - requests[i] = buf.String() + requests = append(requests, buf.String()) } return requests, nil } diff --git a/synthesizer_test.go b/synthesizer_test.go index 08b1f25..ca1c2ef 100644 --- a/synthesizer_test.go +++ b/synthesizer_test.go @@ -10,20 +10,29 @@ import ( func TestSynthesizeRequests(t *testing.T) { is := is.New(t) + + // Read example request r := New() err := r.Read("./test/post.rest") is.NoErr(err) + table := []struct { lang string ft string }{{"javascript", "js"}, {"go", "go"}, {"curl", "curl"}} + for _, tt := range table { tt := tt t.Run(tt.lang, func(t *testing.T) { - requests, err := r.SynthisizeRequests(tt.lang) + // Gen requests + requests, err := r.SynthesizeRequests(tt.lang) is.NoErr(err) + + // Get answer ans, err := ioutil.ReadFile(fmt.Sprintf("./test/template_request.%s", tt.ft)) is.NoErr(err) + + // Check answer for i, c := range requests[0] { is.Equal(rune(ans[i]), c) } @@ -33,21 +42,21 @@ func TestSynthesizeRequests(t *testing.T) { func TestSynthesizeClient(t *testing.T) { is := is.New(t) + + // Get all requests r := New() err := r.Read("./test/client.rest") is.NoErr(err) + table := []struct { lang string ft string - }{ - {"javascript", "js"}, - {"go", "go"}, - {"curl", "curl"}, - } + }{{"javascript", "js"}, {"go", "go"}, {"curl", "curl"}} + for _, tt := range table { tt := tt t.Run(tt.lang, func(t *testing.T) { - _, err := r.SynthisizeClient(tt.lang) + _, err := r.SynthesizeClient(tt.lang) is.NoErr(err) }) } diff --git a/test/client.rest b/test/client.rest index bd0d11f..ba07d48 100644 --- a/test/client.rest +++ b/test/client.rest @@ -7,15 +7,21 @@ GET / set URL http://localhost:8080 -skip - delay 5s label PostThing ${URL} -POST /user Content-Type: application/json +POST /user { "user": "taybart", "11": 12, } + +--- + +skip + +label SkippedThing +${URL} +GET /user diff --git a/test/runtime.rest b/test/runtime.rest new file mode 100644 index 0000000..0bf07ca --- /dev/null +++ b/test/runtime.rest @@ -0,0 +1,15 @@ + +http://localhost:8080 +POST /login +{ + "username": "test", + "password": "password" +} + +take auth_token as AUTH_TOKEN + +--- + +http://localhost:8080 +Authorization: Bearer ${AUTH_TOKEN} +GET /account From 7a34f17b4811fccdaff76fb342c6aa97a3594590 Mon Sep 17 00:00:00 2001 From: Taylor Bartlett Date: Mon, 2 Nov 2020 07:54:25 -0700 Subject: [PATCH 2/3] Fix...well idk how that was working --- builder.go | 8 ++++---- lex.go | 21 +++++++++++---------- lex_test.go | 1 + rest.go | 13 ++++++++++--- rest_test.go | 19 +++++++++++++++++++ 5 files changed, 45 insertions(+), 17 deletions(-) diff --git a/builder.go b/builder.go index 6eb150c..1add16b 100644 --- a/builder.go +++ b/builder.go @@ -28,10 +28,6 @@ type request struct { // buildRequest : generate http.Request from parsed input func buildRequest(input metaRequest, variables map[string]restVar) (req request, err error) { - if err = isValidMetaRequest(input); err != nil { - return - } - if input.reinterpret { log.Debug("Re-interpreting request", variables) l := newLexer(false) @@ -42,6 +38,10 @@ func buildRequest(input metaRequest, variables map[string]restVar) (req request, } } + if err = isValidMetaRequest(input); err != nil { + return + } + var r *http.Request url := fmt.Sprintf("%s%s", input.url, input.path) if !input.skip { // don't validate if skipped diff --git a/lex.go b/lex.go index 1921786..e63a3e8 100644 --- a/lex.go +++ b/lex.go @@ -101,13 +101,6 @@ func (l *lexer) parse(scanner *bufio.Scanner) (requests requestBatch, err error) if err != nil { return } - rtVars := make(map[string]restVar) - for k, v := range l.variables { - if v.runtime { - log.Debugf("var: %s is runtime\n", k) - rtVars[k] = v - } - } var rs []metaRequest if l.concurrent { @@ -118,6 +111,14 @@ func (l *lexer) parse(scanner *bufio.Scanner) (requests requestBatch, err error) if err != nil { return } + rtVars := make(map[string]restVar) + for k, v := range l.variables { + if v.runtime { + log.Debugf("var: %s is runtime\n", k) + } + rtVars[k] = v + } + l.purgeVars() return requestBatch{ requests: rs, rtVars: rtVars, @@ -165,7 +166,6 @@ func (l *lexer) parseSerial(input []metaRequest) (reqs []metaRequest, err error) reqs = append(reqs, lexed) } log.Debugf("Parsed %d blocks\n", len(reqs)) - l.purgeVars() return } @@ -181,7 +181,6 @@ func (l *lexer) parseConcurrent(input []metaRequest) (reqs []metaRequest, err er reqs = append(reqs, r) } log.Debug("Done") - l.purgeVars() return } @@ -208,6 +207,8 @@ func (l *lexer) parseBlock(block []string) (metaRequest, error) { switch { case rxSkip.MatchString(line): req.skip = true + case rxRuntimeVar.MatchString(ln): + continue case rxExpect.MatchString(line): m := rxExpect.FindStringSubmatch(line) if len(m) == 1 { @@ -301,9 +302,9 @@ func (l lexer) checkForUndeclaredVariables(line string) (string, bool, error) { tmp = strings.ReplaceAll(tmp, match[0], v.value) return tmp, false, nil } + log.Debug(line, "->", tmp) return "", false, fmt.Errorf("Saw variable %s%s%s and did not have a value for it", log.Blue, match[1], log.Rtd) - log.Debug(line, "->", tmp) } } return tmp, reinterpret, nil diff --git a/lex_test.go b/lex_test.go index 097c0a4..707b987 100644 --- a/lex_test.go +++ b/lex_test.go @@ -35,6 +35,7 @@ func TestLexFiles(t *testing.T) { {name: "delay", fn: "./test/delay.rest", res: true}, {name: "expect", fn: "./test/expect.rest", res: true}, {name: "skip", fn: "./test/skip.rest", res: true}, + {name: "runtime", fn: "./test/runtime.rest", res: true}, {name: "invalid", fn: "./test/invalid.rest", res: false}, // TODO add individual failures } for _, tt := range files { diff --git a/rest.go b/rest.go index b550a84..04be5a8 100644 --- a/rest.go +++ b/rest.go @@ -90,6 +90,7 @@ func (r *Rest) read(scanner *bufio.Scanner, concurrent bool) error { // Exec : do all loaded requests func (r Rest) Exec() (successful []string, err error) { for i, l := range r.lexed.requests { + log.Debug("Building request block", i) var req request req, err = buildRequest(l, r.lexed.rtVars) if err != nil { @@ -109,6 +110,7 @@ func (r Rest) Exec() (successful []string, err error) { // continue } + log.Debug("Checking expectation") err = r.CheckExpectation(req, resp) if err != nil { // failed = append(failed, err.Error()) @@ -116,6 +118,7 @@ func (r Rest) Exec() (successful []string, err error) { return } + log.Debug("Take output into runtime bars") err = r.takeVariables(resp, &r.lexed.rtVars) if err != nil { return @@ -251,17 +254,21 @@ func (r Rest) IsRestFile(fn string) (bool, error) { } func (rest Rest) takeVariables(res *http.Response, rtVars *map[string]restVar) (err error) { - defer res.Body.Close() body, err := ioutil.ReadAll(res.Body) if err != nil { return } + res.Body.Close() if len(body) == 0 { return } - var j map[string]string + + res.Body = ioutil.NopCloser(bytes.NewBuffer(body)) // put body back + + var j map[string]interface{} err = json.Unmarshal(body, &j) if err != nil { + err = fmt.Errorf("could not take variables from request %w", err) return } for k, v := range *rtVars { @@ -269,7 +276,7 @@ func (rest Rest) takeVariables(res *http.Response, rtVars *map[string]restVar) ( if v.value == jk { (*rtVars)[k] = restVar{ name: k, - value: jv, + value: jv.(string), } } } diff --git a/rest_test.go b/rest_test.go index f924ebb..ea0754c 100644 --- a/rest_test.go +++ b/rest_test.go @@ -2,6 +2,7 @@ package rest import ( "bytes" + "encoding/json" "io/ioutil" "net/http" "testing" @@ -212,9 +213,22 @@ func TestRuntimeVariables(t *testing.T) { r := New() err := r.Read("./test/runtime.rest") is.NoErr(err) + + loginCalled := false + accountCalled := false client := NewTestClient(func(r *http.Request) *http.Response { switch r.URL.Path { case "/login": + t.Log("login") + var j map[string]string + err = json.NewDecoder(r.Body).Decode(&j) + is.NoErr(err) + + is.Equal(j["username"], "test") + is.Equal(j["password"], "password") + + loginCalled = true + // Test request parameters return &http.Response{ StatusCode: 200, @@ -224,9 +238,11 @@ func TestRuntimeVariables(t *testing.T) { Header: make(http.Header), } case "/account": + t.Log("account") if r.Header.Get("Authorization") != "Bearer test" { t.Fatal("auth_token was not present during second call") } + accountCalled = true return &http.Response{ StatusCode: 200, // Must be set to non-nil value or it panics @@ -241,4 +257,7 @@ func TestRuntimeVariables(t *testing.T) { r.SetClient(client) _, err = r.Exec() is.NoErr(err) + + is.True(loginCalled) + is.True(accountCalled) } From 6cce87ffb6e99c89a8c74b775751a1cc6aa7be41 Mon Sep 17 00:00:00 2001 From: Taylor Bartlett Date: Mon, 2 Nov 2020 07:57:06 -0700 Subject: [PATCH 3/3] Bump go to 1.15 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 25e1338..23093d9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/taybart/rest -go 1.14 +go 1.15 require ( github.com/matryer/is v1.2.0