diff --git a/golink.go b/golink.go index 151b3d2..1b01977 100644 --- a/golink.go +++ b/golink.go @@ -384,7 +384,13 @@ func deleteLinkStats(link *Link) { // requests. It redirects all requests to the HTTPs version of the same URL. func redirectHandler(hostname string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, (&url.URL{Scheme: "https", Host: hostname, Path: r.URL.Path}).String(), http.StatusFound) + u := &url.URL{ + Scheme: "https", + Host: hostname, + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + http.Redirect(w, r, u.String(), http.StatusFound) }) } diff --git a/golink_test.go b/golink_test.go index 12c0263..5bfe131 100644 --- a/golink_test.go +++ b/golink_test.go @@ -637,3 +637,16 @@ func TestNoHSTSShortDomain(t *testing.T) { }) } } + +func TestHTTPSRedirectHandlerWithQuery(t *testing.T) { + h := redirectHandler("foobar.com") + r := httptest.NewRequest("GET", "http://example.com/?query=bar", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + if w.Code != http.StatusFound { + t.Errorf("got %d; want %d", w.Code, http.StatusFound) + } + if w.Header().Get("Location") != "https://foobar.com/?query=bar" { + t.Errorf("got %q; want %q", w.Header().Get("Location"), "https://foobar.com/?query=bar") + } +}