diff --git a/frankenphp_test.go b/frankenphp_test.go index 24241e058b..4c36e0c38f 100644 --- a/frankenphp_test.go +++ b/frankenphp_test.go @@ -107,19 +107,40 @@ func runTest(t *testing.T, test func(func(http.ResponseWriter, *http.Request), * wg.Wait() } +func testRequest(req *http.Request, handler func(http.ResponseWriter, *http.Request), t *testing.T) (string, *http.Response) { + t.Helper() + w := httptest.NewRecorder() + handler(w, req) + resp := w.Result() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + return string(body), resp +} + +func testGet(url string, handler func(http.ResponseWriter, *http.Request), t *testing.T) (string, *http.Response) { + t.Helper() + req := httptest.NewRequest(http.MethodGet, url, nil) + + return testRequest(req, handler, t) +} + +func testPost(url string, body string, handler func(http.ResponseWriter, *http.Request), t *testing.T) (string, *http.Response) { + t.Helper() + req := httptest.NewRequest(http.MethodPost, url, nil) + req.Body = io.NopCloser(strings.NewReader(body)) + + return testRequest(req, handler, t) +} + func TestHelloWorld_module(t *testing.T) { testHelloWorld(t, nil) } func TestHelloWorld_worker(t *testing.T) { testHelloWorld(t, &testOptions{workerScript: "index.php"}) } func testHelloWorld(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/index.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - assert.Equal(t, fmt.Sprintf("I am by birth a Genevese (%d)", i), string(body)) + body, _ := testGet(fmt.Sprintf("http://example.com/index.php?i=%d", i), handler, t) + assert.Equal(t, fmt.Sprintf("I am by birth a Genevese (%d)", i), body) }, opts) } @@ -129,13 +150,8 @@ func TestFinishRequest_worker(t *testing.T) { } func testFinishRequest(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/finish-request.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - assert.Equal(t, fmt.Sprintf("This is output %d\n", i), string(body)) + body, _ := testGet(fmt.Sprintf("http://example.com/finish-request.php?i=%d", i), handler, t) + assert.Equal(t, fmt.Sprintf("This is output %d\n", i), body) }, opts) } @@ -150,39 +166,33 @@ func testServerVariable(t *testing.T, opts *testOptions) { req := httptest.NewRequest("POST", fmt.Sprintf("http://example.com/server-variable.php/baz/bat?foo=a&bar=b&i=%d#hash", i), strings.NewReader("foo")) req.SetBasicAuth(strings.Clone("kevin"), strings.Clone("password")) req.Header.Add(strings.Clone("Content-Type"), strings.Clone("text/plain")) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - strBody := string(body) - - assert.Contains(t, strBody, "[REMOTE_HOST]") - assert.Contains(t, strBody, "[REMOTE_USER] => kevin") - assert.Contains(t, strBody, "[PHP_AUTH_USER] => kevin") - assert.Contains(t, strBody, "[PHP_AUTH_PW] => password") - assert.Contains(t, strBody, "[HTTP_AUTHORIZATION] => Basic a2V2aW46cGFzc3dvcmQ=") - assert.Contains(t, strBody, "[DOCUMENT_ROOT]") - assert.Contains(t, strBody, "[PHP_SELF] => /server-variable.php/baz/bat") - assert.Contains(t, strBody, "[CONTENT_TYPE] => text/plain") - assert.Contains(t, strBody, fmt.Sprintf("[QUERY_STRING] => foo=a&bar=b&i=%d#hash", i)) - assert.Contains(t, strBody, fmt.Sprintf("[REQUEST_URI] => /server-variable.php/baz/bat?foo=a&bar=b&i=%d#hash", i)) - assert.Contains(t, strBody, "[CONTENT_LENGTH]") - assert.Contains(t, strBody, "[REMOTE_ADDR]") - assert.Contains(t, strBody, "[REMOTE_PORT]") - assert.Contains(t, strBody, "[REQUEST_SCHEME] => http") - assert.Contains(t, strBody, "[DOCUMENT_URI]") - assert.Contains(t, strBody, "[AUTH_TYPE]") - assert.Contains(t, strBody, "[REMOTE_IDENT]") - assert.Contains(t, strBody, "[REQUEST_METHOD] => POST") - assert.Contains(t, strBody, "[SERVER_NAME] => example.com") - assert.Contains(t, strBody, "[SERVER_PROTOCOL] => HTTP/1.1") - assert.Contains(t, strBody, "[SCRIPT_FILENAME]") - assert.Contains(t, strBody, "[SERVER_SOFTWARE] => FrankenPHP") - assert.Contains(t, strBody, "[REQUEST_TIME_FLOAT]") - assert.Contains(t, strBody, "[REQUEST_TIME]") - assert.Contains(t, strBody, "[SERVER_PORT] => 80") + body, _ := testRequest(req, handler, t) + + assert.Contains(t, body, "[REMOTE_HOST]") + assert.Contains(t, body, "[REMOTE_USER] => kevin") + assert.Contains(t, body, "[PHP_AUTH_USER] => kevin") + assert.Contains(t, body, "[PHP_AUTH_PW] => password") + assert.Contains(t, body, "[HTTP_AUTHORIZATION] => Basic a2V2aW46cGFzc3dvcmQ=") + assert.Contains(t, body, "[DOCUMENT_ROOT]") + assert.Contains(t, body, "[PHP_SELF] => /server-variable.php/baz/bat") + assert.Contains(t, body, "[CONTENT_TYPE] => text/plain") + assert.Contains(t, body, fmt.Sprintf("[QUERY_STRING] => foo=a&bar=b&i=%d#hash", i)) + assert.Contains(t, body, fmt.Sprintf("[REQUEST_URI] => /server-variable.php/baz/bat?foo=a&bar=b&i=%d#hash", i)) + assert.Contains(t, body, "[CONTENT_LENGTH]") + assert.Contains(t, body, "[REMOTE_ADDR]") + assert.Contains(t, body, "[REMOTE_PORT]") + assert.Contains(t, body, "[REQUEST_SCHEME] => http") + assert.Contains(t, body, "[DOCUMENT_URI]") + assert.Contains(t, body, "[AUTH_TYPE]") + assert.Contains(t, body, "[REMOTE_IDENT]") + assert.Contains(t, body, "[REQUEST_METHOD] => POST") + assert.Contains(t, body, "[SERVER_NAME] => example.com") + assert.Contains(t, body, "[SERVER_PROTOCOL] => HTTP/1.1") + assert.Contains(t, body, "[SCRIPT_FILENAME]") + assert.Contains(t, body, "[SERVER_SOFTWARE] => FrankenPHP") + assert.Contains(t, body, "[REQUEST_TIME_FLOAT]") + assert.Contains(t, body, "[REQUEST_TIME]") + assert.Contains(t, body, "[SERVER_PORT] => 80") }, opts) } @@ -210,19 +220,12 @@ func testPathInfo(t *testing.T, opts *testOptions) { assert.NoError(t, err) } - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/pathinfo/%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - strBody := string(body) + body, _ := testGet(fmt.Sprintf("http://example.com/pathinfo/%d", i), handler, t) - assert.Contains(t, strBody, "[PATH_INFO] => /pathinfo") - assert.Contains(t, strBody, fmt.Sprintf("[REQUEST_URI] => /pathinfo/%d", i)) - assert.Contains(t, strBody, "[PATH_TRANSLATED] =>") - assert.Contains(t, strBody, "[SCRIPT_NAME] => /server-variable.php") + assert.Contains(t, body, "[PATH_INFO] => /pathinfo") + assert.Contains(t, body, fmt.Sprintf("[REQUEST_URI] => /pathinfo/%d", i)) + assert.Contains(t, body, "[PATH_TRANSLATED] =>") + assert.Contains(t, body, "[SCRIPT_NAME] => /server-variable.php") }, opts) } @@ -231,14 +234,9 @@ func TestHeaders_module(t *testing.T) { testHeaders(t, nil) } func TestHeaders_worker(t *testing.T) { testHeaders(t, &testOptions{workerScript: "headers.php"}) } func testHeaders(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/headers.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, resp := testGet(fmt.Sprintf("http://example.com/headers.php?i=%d", i), handler, t) - assert.Equal(t, "Hello", string(body)) + assert.Equal(t, "Hello", body) assert.Equal(t, 201, resp.StatusCode) assert.Equal(t, "bar", resp.Header.Get("Foo")) assert.Equal(t, "bar2", resp.Header.Get("Foo2")) @@ -254,12 +252,7 @@ func TestResponseHeaders_worker(t *testing.T) { } func testResponseHeaders(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/response-headers.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, resp := testGet(fmt.Sprintf("http://example.com/response-headers.php?i=%d", i), handler, t) if i%3 != 0 { assert.Equal(t, i+100, resp.StatusCode) @@ -267,11 +260,11 @@ func testResponseHeaders(t *testing.T, opts *testOptions) { assert.Equal(t, 200, resp.StatusCode) } - assert.Contains(t, string(body), "'X-Powered-By' => 'PH") - assert.Contains(t, string(body), "'Foo' => 'bar',") - assert.Contains(t, string(body), "'Foo2' => 'bar2',") - assert.Contains(t, string(body), fmt.Sprintf("'I' => '%d',", i)) - assert.NotContains(t, string(body), "Invalid") + assert.Contains(t, body, "'X-Powered-By' => 'PH") + assert.Contains(t, body, "'Foo' => 'bar',") + assert.Contains(t, body, "'Foo2' => 'bar2',") + assert.Contains(t, body, fmt.Sprintf("'I' => '%d',", i)) + assert.NotContains(t, body, "Invalid") }, opts) } @@ -279,14 +272,9 @@ func TestInput_module(t *testing.T) { testInput(t, nil) } func TestInput_worker(t *testing.T) { testInput(t, &testOptions{workerScript: "input.php"}) } func testInput(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("POST", "http://example.com/input.php", strings.NewReader(fmt.Sprintf("post data %d", i))) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, resp := testPost("http://example.com/input.php", fmt.Sprintf("post data %d", i), handler, t) - assert.Equal(t, fmt.Sprintf("post data %d", i), string(body)) + assert.Equal(t, fmt.Sprintf("post data %d", i), body) assert.Equal(t, "bar", resp.Header.Get("Foo")) }, opts) } @@ -300,16 +288,12 @@ func testPostSuperGlobals(t *testing.T, opts *testOptions) { formData := url.Values{"baz": {"bat"}, "i": {fmt.Sprintf("%d", i)}} req := httptest.NewRequest("POST", fmt.Sprintf("http://example.com/super-globals.php?foo=bar&iG=%d", i), strings.NewReader(formData.Encode())) req.Header.Set("Content-Type", strings.Clone("application/x-www-form-urlencoded")) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testRequest(req, handler, t) - assert.Contains(t, string(body), "'foo' => 'bar'") - assert.Contains(t, string(body), fmt.Sprintf("'i' => '%d'", i)) - assert.Contains(t, string(body), "'baz' => 'bat'") - assert.Contains(t, string(body), fmt.Sprintf("'iG' => '%d'", i)) + assert.Contains(t, body, "'foo' => 'bar'") + assert.Contains(t, body, fmt.Sprintf("'i' => '%d'", i)) + assert.Contains(t, body, "'baz' => 'bat'") + assert.Contains(t, body, fmt.Sprintf("'iG' => '%d'", i)) }, opts) } @@ -320,14 +304,10 @@ func testCookies(t *testing.T, opts *testOptions) { req := httptest.NewRequest("GET", "http://example.com/cookies.php", nil) req.AddCookie(&http.Cookie{Name: "foo", Value: "bar"}) req.AddCookie(&http.Cookie{Name: "i", Value: fmt.Sprintf("%d", i)}) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testRequest(req, handler, t) - assert.Contains(t, string(body), "'foo' => 'bar'") - assert.Contains(t, string(body), fmt.Sprintf("'i' => '%d'", i)) + assert.Contains(t, body, "'foo' => 'bar'") + assert.Contains(t, body, fmt.Sprintf("'i' => '%d'", i)) }, opts) } @@ -337,21 +317,17 @@ func TestMalformedCookie(t *testing.T) { req.Header.Add("Cookie", "foo =bar; ===;;==; .dot.=val ;\x00 ; PHPSESSID=1234") // Multiple Cookie header should be joined https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2.5 req.Header.Add("Cookie", "secondCookie=test; secondCookie=overwritten") - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testRequest(req, handler, t) - assert.Contains(t, string(body), "'foo_' => 'bar'") - assert.Contains(t, string(body), "'_dot_' => 'val '") + assert.Contains(t, body, "'foo_' => 'bar'") + assert.Contains(t, body, "'_dot_' => 'val '") // PHPSESSID should still be present since we remove the null byte - assert.Contains(t, string(body), "'PHPSESSID' => '1234'") + assert.Contains(t, body, "'PHPSESSID' => '1234'") // The cookie in the second headers should be present, // but it should not be overwritten by following values - assert.Contains(t, string(body), "'secondCookie' => 'test'") + assert.Contains(t, body, "'secondCookie' => 'test'") }, &testOptions{nbParallelRequests: 1}) } @@ -391,19 +367,14 @@ func TestPhpInfo_worker(t *testing.T) { testPhpInfo(t, &testOptions{workerScript func testPhpInfo(t *testing.T, opts *testOptions) { var logOnce sync.Once runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/phpinfo.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testGet(fmt.Sprintf("http://example.com/phpinfo.php?i=%d", i), handler, t) logOnce.Do(func() { - t.Log(string(body)) + t.Log(body) }) - assert.Contains(t, string(body), "frankenphp") - assert.Contains(t, string(body), fmt.Sprintf("i=%d", i)) + assert.Contains(t, body, "frankenphp") + assert.Contains(t, body, fmt.Sprintf("i=%d", i)) }, opts) } @@ -413,17 +384,12 @@ func TestPersistentObject_worker(t *testing.T) { } func testPersistentObject(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/persistent-object.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testGet(fmt.Sprintf("http://example.com/persistent-object.php?i=%d", i), handler, t) assert.Equal(t, fmt.Sprintf(`request: %d class exists: 1 id: obj1 -object id: 1`, i), string(body)) +object id: 1`, i), body) }, opts) } @@ -433,15 +399,10 @@ func TestAutoloader_worker(t *testing.T) { } func testAutoloader(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/autoloader.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testGet(fmt.Sprintf("http://example.com/autoloader.php?i=%d", i), handler, t) assert.Equal(t, fmt.Sprintf(`request %d -my_autoloader`, i), string(body)) +my_autoloader`, i), body) }, opts) } @@ -498,15 +459,10 @@ func TestException_worker(t *testing.T) { } func testException(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/exception.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testGet(fmt.Sprintf("http://example.com/exception.php?i=%d", i), handler, t) - assert.Contains(t, string(body), "hello") - assert.Contains(t, string(body), fmt.Sprintf(`Uncaught Exception: request %d`, i)) + assert.Contains(t, body, "hello") + assert.Contains(t, body, fmt.Sprintf(`Uncaught Exception: request %d`, i)) }, opts) } @@ -585,18 +541,14 @@ func TestLargeRequest_worker(t *testing.T) { } func testLargeRequest(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest( - "POST", + body, _ := testPost( fmt.Sprintf("http://example.com/large-request.php?i=%d", i), - strings.NewReader(strings.Repeat("f", 6_048_576)), + strings.Repeat("f", 6_048_576), + handler, + t, ) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - assert.Contains(t, string(body), fmt.Sprintf("Request body size: 6048576 (%d)", i)) + assert.Contains(t, body, fmt.Sprintf("Request body size: 6048576 (%d)", i)) }, opts) } @@ -616,14 +568,8 @@ func TestFiberNonCgo_worker(t *testing.T) { } func testFiberNoCgo(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/fiber-no-cgo.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - assert.Equal(t, string(body), fmt.Sprintf("Fiber %d", i)) + body, _ := testGet(fmt.Sprintf("http://example.com/fiber-no-cgo.php?i=%d", i), handler, t) + assert.Equal(t, body, fmt.Sprintf("Fiber %d", i)) }, opts) } @@ -633,14 +579,8 @@ func TestFiberBasic_worker(t *testing.T) { } func testFiberBasic(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/fiber-basic.php?i=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - assert.Equal(t, string(body), fmt.Sprintf("Fiber %d", i)) + body, _ := testGet(fmt.Sprintf("http://example.com/fiber-basic.php?i=%d", i), handler, t) + assert.Equal(t, body, fmt.Sprintf("Fiber %d", i)) }, opts) } @@ -653,27 +593,17 @@ func testRequestHeaders(t *testing.T, opts *testOptions) { req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/request-headers.php?i=%d", i), nil) req.Header.Add(strings.Clone("Content-Type"), strings.Clone("text/plain")) req.Header.Add(strings.Clone("Frankenphp-I"), strings.Clone(strconv.Itoa(i))) + body, _ := testRequest(req, handler, t) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - assert.Contains(t, string(body), "[Content-Type] => text/plain") - assert.Contains(t, string(body), fmt.Sprintf("[Frankenphp-I] => %d", i)) + assert.Contains(t, body, "[Content-Type] => text/plain") + assert.Contains(t, body, fmt.Sprintf("[Frankenphp-I] => %d", i)) }, opts) } func TestFailingWorker(t *testing.T) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", "http://example.com/failing-worker.php", nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - assert.Contains(t, string(body), "ok") + body, _ := testGet("http://example.com/failing-worker.php", handler, t) + assert.Contains(t, body, "ok") }, &testOptions{workerScript: "failing-worker.php"}) } @@ -689,12 +619,7 @@ func testEnv(t *testing.T, opts *testOptions) { assert.NoError(t, os.Setenv("EMPTY", "")) runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/env/test-env.php?var=%d", i), nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testGet(fmt.Sprintf("http://example.com/env/test-env.php?var=%d", i), handler, t) // execute the script as regular php script cmd := exec.Command("php", "testdata/env/test-env.php", strconv.Itoa(i)) @@ -704,18 +629,18 @@ func testEnv(t *testing.T, opts *testOptions) { stdoutStderr = []byte("Set MY_VAR successfully.\nMY_VAR = HelloWorld\nUnset MY_VAR successfully.\nMY_VAR is unset.\nMY_VAR set to empty successfully.\nMY_VAR = \nUnset NON_EXISTING_VAR successfully.\n") } - assert.Equal(t, string(stdoutStderr), string(body)) + assert.Equal(t, string(stdoutStderr), body) }, opts) } func TestEnvIsResetInNonWorkerMode(t *testing.T) { assert.NoError(t, os.Setenv("test", "")) runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - putResult := fetchBody("GET", fmt.Sprintf("http://example.com/env/putenv.php?key=test&put=%d", i), handler) + putResult, _ := testGet(fmt.Sprintf("http://example.com/env/putenv.php?key=test&put=%d", i), handler, t) assert.Equal(t, fmt.Sprintf("test=%d", i), putResult, "putenv and then echo getenv") - getResult := fetchBody("GET", "http://example.com/env/putenv.php?key=test", handler) + getResult, _ := testGet("http://example.com/env/putenv.php?key=test", handler, t) assert.Equal(t, "test=", getResult, "putenv should be reset across requests") }, &testOptions{}) @@ -725,11 +650,11 @@ func TestEnvIsResetInNonWorkerMode(t *testing.T) { func TestEnvIsNotResetInWorkerMode(t *testing.T) { assert.NoError(t, os.Setenv("index", "")) runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { - putResult := fetchBody("GET", fmt.Sprintf("http://example.com/env/remember-env.php?index=%d", i), handler) + putResult, _ := testGet(fmt.Sprintf("http://example.com/env/remember-env.php?index=%d", i), handler, t) assert.Equal(t, "success", putResult, "putenv and then echo getenv") - getResult := fetchBody("GET", "http://example.com/env/remember-env.php", handler) + getResult, _ := testGet("http://example.com/env/remember-env.php", handler, t) assert.Equal(t, "success", getResult, "putenv should not be reset across worker requests") }, &testOptions{workerScript: "env/remember-env.php"}) @@ -739,7 +664,7 @@ func TestEnvIsNotResetInWorkerMode(t *testing.T) { func TestModificationsToEnvPersistAcrossRequests(t *testing.T) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, i int) { for j := 0; j < 3; j++ { - result := fetchBody("GET", "http://example.com/env/overwrite-env.php", handler) + result, _ := testGet("http://example.com/env/overwrite-env.php", handler, t) assert.Equal(t, "custom_value", result, "a var directly added to $_ENV should persist") } }, &testOptions{ @@ -765,11 +690,7 @@ func testFileUpload(t *testing.T, opts *testOptions) { req := httptest.NewRequest("POST", "http://example.com/file-upload.php", requestBody) req.Header.Add("Content-Type", writer.FormDataContentType()) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, _ := testRequest(req, handler, t) assert.Contains(t, string(body), "Upload OK") }, opts) @@ -1015,15 +936,10 @@ func testRejectInvalidHeaders(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, _ int) { req := httptest.NewRequest("GET", "http://example.com/headers.php", nil) req.Header.Add(header[0], header[1]) - - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, resp := testRequest(req, handler, t) assert.Equal(t, 400, resp.StatusCode) - assert.Contains(t, string(body), "invalid") + assert.Contains(t, body, "invalid") }, opts) } } @@ -1035,11 +951,7 @@ func TestFlushEmptyRespnse_worker(t *testing.T) { func testFlushEmptyResponse(t *testing.T, opts *testOptions) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, _ int) { - req := httptest.NewRequest("GET", "http://example.com/only-headers.php", nil) - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() + _, resp := testGet("http://example.com/only-headers.php", handler, t) assert.Equal(t, 204, resp.StatusCode) }, opts) } @@ -1048,13 +960,13 @@ func testFlushEmptyResponse(t *testing.T, opts *testOptions) { // Make sure referenced streams are not cleaned up func TestFileStreamInWorkerMode(t *testing.T) { runTest(t, func(handler func(http.ResponseWriter, *http.Request), _ *httptest.Server, _ int) { - resp1 := fetchBody("GET", "http://example.com/file-stream.php", handler) + resp1, _ := testGet("http://example.com/file-stream.php", handler, t) assert.Equal(t, resp1, "word1") - resp2 := fetchBody("GET", "http://example.com/file-stream.php", handler) + resp2, _ := testGet("http://example.com/file-stream.php", handler, t) assert.Equal(t, resp2, "word2") - resp3 := fetchBody("GET", "http://example.com/file-stream.php", handler) + resp3, _ := testGet("http://example.com/file-stream.php", handler, t) assert.Equal(t, resp3, "word3") }, &testOptions{workerScript: "file-stream.php", nbParallelRequests: 1, nbWorkers: 1}) } @@ -1074,38 +986,23 @@ func FuzzRequest(f *testing.F) { req.URL = &url.URL{RawQuery: "test=" + fuzzedString, Path: "/server-variable.php/" + fuzzedString} req.Header.Add(strings.Clone("Fuzzed"), strings.Clone(fuzzedString)) req.Header.Add(strings.Clone("Content-Type"), fuzzedString) - - w := httptest.NewRecorder() - handler(w, req) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) + body, resp := testRequest(req, handler, t) // The response status must be 400 if the request path contains null bytes if strings.Contains(req.URL.Path, "\x00") { assert.Equal(t, 400, resp.StatusCode) - assert.Contains(t, string(body), "Invalid request path") + assert.Contains(t, body, "Invalid request path") return } // The fuzzed string must be present in the path - assert.Contains(t, string(body), fmt.Sprintf("[PATH_INFO] => /%s", fuzzedString)) - assert.Contains(t, string(body), fmt.Sprintf("[PATH_TRANSLATED] => %s", filepath.Join(absPath, fuzzedString))) + assert.Contains(t, body, fmt.Sprintf("[PATH_INFO] => /%s", fuzzedString)) + assert.Contains(t, body, fmt.Sprintf("[PATH_TRANSLATED] => %s", filepath.Join(absPath, fuzzedString))) // Headers should always be present even if empty - assert.Contains(t, string(body), fmt.Sprintf("[CONTENT_TYPE] => %s", fuzzedString)) - assert.Contains(t, string(body), fmt.Sprintf("[HTTP_FUZZED] => %s", fuzzedString)) + assert.Contains(t, body, fmt.Sprintf("[CONTENT_TYPE] => %s", fuzzedString)) + assert.Contains(t, body, fmt.Sprintf("[HTTP_FUZZED] => %s", fuzzedString)) }, &testOptions{workerScript: "request-headers.php"}) }) } - -func fetchBody(method string, url string, handler func(http.ResponseWriter, *http.Request)) string { - req := httptest.NewRequest(method, url, nil) - w := httptest.NewRecorder() - handler(w, req) - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - return string(body) -} diff --git a/watcher_test.go b/watcher_test.go index afd4e8c88c..ea5134a9e8 100644 --- a/watcher_test.go +++ b/watcher_test.go @@ -42,14 +42,14 @@ func TestWorkersShouldNotReloadOnExcludingPattern(t *testing.T) { func pollForWorkerReset(t *testing.T, handler func(http.ResponseWriter, *http.Request), limit int) bool { // first we make an initial request to start the request counter - body := fetchBody("GET", "http://example.com/worker-with-counter.php", handler) + body, _ := testGet("http://example.com/worker-with-counter.php", handler, t) assert.Equal(t, "requests:1", body) // now we spam file updates and check if the request counter resets for i := 0; i < limit; i++ { updateTestFile("./testdata/files/test.txt", "updated", t) time.Sleep(pollingTime * time.Millisecond) - body := fetchBody("GET", "http://example.com/worker-with-counter.php", handler) + body, _ := testGet("http://example.com/worker-with-counter.php", handler, t) if body == "requests:1" { return true }