diff --git a/drive/driveimpl/compositedav/compositedav.go b/drive/driveimpl/compositedav/compositedav.go index 8b41871ad49ce..e5c16f1a107c2 100644 --- a/drive/driveimpl/compositedav/compositedav.go +++ b/drive/driveimpl/compositedav/compositedav.go @@ -93,8 +93,17 @@ var cacheInvalidatingMethods = map[string]bool{ // ServeHTTP implements http.Handler. func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == "PROPFIND" { - h.handlePROPFIND(w, r) + pathComponents := shared.CleanAndSplit(r.URL.Path) + mpl := h.maxPathLength(r) + + rewriteIfHeader(r, pathComponents, mpl) + + switch r.Method { + case "PROPFIND": + h.handlePROPFIND(w, r, pathComponents, mpl) + return + case "LOCK": + h.handleLOCK(w, r, pathComponents, mpl) return } @@ -107,9 +116,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.StatCache.invalidate() } - mpl := h.maxPathLength(r) - pathComponents := shared.CleanAndSplit(r.URL.Path) - if len(pathComponents) >= mpl { h.delegate(mpl, pathComponents[mpl-1:], w, r) return diff --git a/drive/driveimpl/compositedav/propfind.go b/drive/driveimpl/compositedav/propfind.go deleted file mode 100644 index 5e6ccfa0bb8e1..0000000000000 --- a/drive/driveimpl/compositedav/propfind.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package compositedav - -import ( - "bytes" - "fmt" - "math" - "net/http" - "regexp" - - "tailscale.com/drive/driveimpl/shared" -) - -var ( - hrefRegex = regexp.MustCompile(`(?s)/?([^<]*)/?`) -) - -func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request) { - pathComponents := shared.CleanAndSplit(r.URL.Path) - mpl := h.maxPathLength(r) - if !shared.IsRoot(r.URL.Path) && len(pathComponents)+getDepth(r) > mpl { - // Delegate to a Child. - depth := getDepth(r) - - status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) { - // Use a buffering ResponseWriter so that we can manipulate the result. - // The only thing we use from the original ResponseWriter is Header(). - bw := &bufferingResponseWriter{ResponseWriter: w} - - mpl := h.maxPathLength(r) - h.delegate(mpl, pathComponents[mpl-1:], bw, r) - - // Fixup paths to add the requested path as a prefix. - pathPrefix := shared.Join(pathComponents[0:mpl]...) - b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("%s/$1", pathPrefix))) - - return bw.status, b - }) - - w.Header().Del("Content-Length") - w.WriteHeader(status) - if result != nil { - w.Write(result) - } - return - } - - h.handle(w, r) -} - -func getDepth(r *http.Request) int { - switch r.Header.Get("Depth") { - case "0": - return 0 - case "1": - return 1 - case "infinity": - return math.MaxInt - } - return 0 -} - -type bufferingResponseWriter struct { - http.ResponseWriter - status int - buf bytes.Buffer -} - -func (bw *bufferingResponseWriter) WriteHeader(statusCode int) { - bw.status = statusCode -} - -func (bw *bufferingResponseWriter) Write(p []byte) (int, error) { - return bw.buf.Write(p) -} diff --git a/drive/driveimpl/compositedav/rewriting.go b/drive/driveimpl/compositedav/rewriting.go new file mode 100644 index 0000000000000..3dcc9be0a7f9a --- /dev/null +++ b/drive/driveimpl/compositedav/rewriting.go @@ -0,0 +1,112 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package compositedav + +import ( + "bytes" + "fmt" + "math" + "net/http" + "regexp" + "strings" + + "tailscale.com/drive/driveimpl/shared" +) + +var ( + responseHrefRegex = regexp.MustCompile(`(?s)()/?([^<]*)/?`) + ifHrefRegex = regexp.MustCompile(`^<(https?://[^/]+)?([^>]+)>`) +) + +func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) { + if shouldDelegateToChild(r, pathComponents, mpl) { + // Delegate to a Child. + depth := getDepth(r) + + status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) { + return h.delegateRewriting(w, r, pathComponents, mpl) + }) + + respondRewritten(w, status, result) + return + } + + h.handle(w, r) +} + +func (h *Handler) handleLOCK(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) { + if shouldDelegateToChild(r, pathComponents, mpl) { + // Delegate to a Child. + status, result := h.delegateRewriting(w, r, pathComponents, mpl) + respondRewritten(w, status, result) + return + } + + http.Error(w, "locking of top level directories is not allowed", http.StatusMethodNotAllowed) +} + +func shouldDelegateToChild(r *http.Request, pathComponents []string, mpl int) bool { + return !shared.IsRoot(r.URL.Path) && len(pathComponents)+getDepth(r) > mpl +} + +func (h *Handler) delegateRewriting(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) (int, []byte) { + // Use a buffering ResponseWriter so that we can manipulate the result. + // The only thing we use from the original ResponseWriter is Header(). + bw := &bufferingResponseWriter{ResponseWriter: w} + + h.delegate(mpl, pathComponents[mpl-1:], bw, r) + + // Fixup paths to add the requested path as a prefix, escaped for inclusion in XML. + pp := shared.EscapeForXML(shared.Join(pathComponents[0:mpl]...)) + b := responseHrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("$1%s/$3", pp))) + return bw.status, b +} + +func respondRewritten(w http.ResponseWriter, status int, result []byte) { + w.Header().Del("Content-Length") + w.WriteHeader(status) + if result != nil { + w.Write(result) + } +} + +func getDepth(r *http.Request) int { + switch r.Header.Get("Depth") { + case "0": + return 0 + case "1": + return 1 + case "infinity": + return math.MaxInt16 // a really large number, but not infinity (avoids wrapping when we do arithmetic with this) + } + return 0 +} + +type bufferingResponseWriter struct { + http.ResponseWriter + status int + buf bytes.Buffer +} + +func (bw *bufferingResponseWriter) WriteHeader(statusCode int) { + bw.status = statusCode +} + +func (bw *bufferingResponseWriter) Write(p []byte) (int, error) { + return bw.buf.Write(p) +} + +func rewriteIfHeader(r *http.Request, pathComponents []string, mpl int) { + ih := r.Header.Get("If") + if ih == "" { + return + } + matches := ifHrefRegex.FindStringSubmatch(ih) + if len(matches) == 3 { + pp := shared.JoinEscaped(pathComponents[0:mpl]...) + p := strings.Replace(shared.JoinEscaped(pathComponents...), pp, "", 1) + nih := ifHrefRegex.ReplaceAllString(ih, fmt.Sprintf("<%s>", p)) + r.Header.Set("If", nih) + } +} diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go index 8e9d1a557a1c5..f9ad4b4428c4a 100644 --- a/drive/driveimpl/drive_test.go +++ b/drive/driveimpl/drive_test.go @@ -14,6 +14,7 @@ import ( "os" "path" "path/filepath" + "regexp" "slices" "strings" "sync" @@ -30,14 +31,19 @@ import ( const ( domain = `test$%domain.com` - remote1 = `rem ote$%1` - remote2 = `_rem ote$%2` - share11 = `sha re$%11` - share12 = `_sha re$%12` - file111 = `fi le$%111.txt` + remote1 = `rem ote$%<>1` + remote2 = `_rem ote$%<>2` + share11 = `sha re$%<>11` + share12 = `_sha re$%<>12` + file111 = `fi le$%<>111.txt` file112 = `file112.txt` ) +var ( + lockRootRegex = regexp.MustCompile(`/?([^<]*)/?`) + lockTokenRegex = regexp.MustCompile(`([0-9]+)/?`) +) + func init() { // set AllowShareAs() to false so that we don't try to use sub-processes // for access files on disk. @@ -145,6 +151,206 @@ func TestSecretTokenAuth(t *testing.T) { } } +func TestLOCK(t *testing.T) { + s := newSystem(t) + + s.addRemote(remote1) + s.addShare(remote1, share11, drive.PermissionReadWrite) + s.writeFile("writing file to read/write remote should succeed", remote1, share11, file111, "hello world", true) + + client := &http.Client{ + Transport: &http.Transport{DisableKeepAlives: true}, + } + + u := fmt.Sprintf("http://%s/%s/%s/%s/%s", + s.local.l.Addr(), + url.PathEscape(domain), + url.PathEscape(remote1), + url.PathEscape(share11), + url.PathEscape(file111)) + + // First acquire a lock with a short timeout + req, err := http.NewRequest("LOCK", u, strings.NewReader(lockBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Depth", "infinity") + req.Header.Set("Timeout", "Second-1") + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("expected LOCK to succeed, but got status %d", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + submatches := lockRootRegex.FindStringSubmatch(string(body)) + if len(submatches) != 2 { + t.Fatal("failed to find lockroot") + } + want := shared.EscapeForXML(pathTo(remote1, share11, file111)) + got := submatches[1] + if got != want { + t.Fatalf("want lockroot %q, got %q", want, got) + } + + submatches = lockTokenRegex.FindStringSubmatch(string(body)) + if len(submatches) != 2 { + t.Fatal("failed to find locktoken") + } + lockToken := submatches[1] + ifHeader := fmt.Sprintf("<%s> (<%s>)", u, lockToken) + + // Then refresh the lock with a longer timeout + req, err = http.NewRequest("LOCK", u, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Depth", "infinity") + req.Header.Set("Timeout", "Second-600") + req.Header.Set("If", ifHeader) + resp, err = client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("expected LOCK refresh to succeed, but got status %d", resp.StatusCode) + } + body, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + submatches = lockRootRegex.FindStringSubmatch(string(body)) + if len(submatches) != 2 { + t.Fatal("failed to find lockroot after refresh") + } + want = shared.EscapeForXML(pathTo(remote1, share11, file111)) + got = submatches[1] + if got != want { + t.Fatalf("want lockroot after refresh %q, got %q", want, got) + } + + submatches = lockTokenRegex.FindStringSubmatch(string(body)) + if len(submatches) != 2 { + t.Fatal("failed to find locktoken after refresh") + } + if submatches[1] != lockToken { + t.Fatalf("on refresh, lock token changed from %q to %q", lockToken, submatches[1]) + } + + // Then wait past the original timeout, then try to delete without the lock + // (should fail) + time.Sleep(1 * time.Second) + req, err = http.NewRequest("DELETE", u, nil) + if err != nil { + log.Fatal(err) + } + resp, err = client.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 423 { + t.Fatalf("deleting without lock token should fail with 423, but got %d", resp.StatusCode) + } + + // Then delete with the lock (should succeed) + req, err = http.NewRequest("DELETE", u, nil) + if err != nil { + log.Fatal(err) + } + req.Header.Set("If", ifHeader) + resp, err = client.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 204 { + t.Fatalf("deleting with lock token should have succeeded with 204, but got %d", resp.StatusCode) + } +} + +func TestUNLOCK(t *testing.T) { + s := newSystem(t) + + s.addRemote(remote1) + s.addShare(remote1, share11, drive.PermissionReadWrite) + s.writeFile("writing file to read/write remote should succeed", remote1, share11, file111, "hello world", true) + + client := &http.Client{ + Transport: &http.Transport{DisableKeepAlives: true}, + } + + u := fmt.Sprintf("http://%s/%s/%s/%s/%s", + s.local.l.Addr(), + url.PathEscape(domain), + url.PathEscape(remote1), + url.PathEscape(share11), + url.PathEscape(file111)) + + // Acquire a lock + req, err := http.NewRequest("LOCK", u, strings.NewReader(lockBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Depth", "infinity") + req.Header.Set("Timeout", "Second-600") + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("expected LOCK to succeed, but got status %d", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + submatches := lockTokenRegex.FindStringSubmatch(string(body)) + if len(submatches) != 2 { + t.Fatal("failed to find locktoken") + } + lockToken := submatches[1] + + // Release the lock + req, err = http.NewRequest("UNLOCK", u, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Lock-Token", fmt.Sprintf("<%s>", lockToken)) + resp, err = client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 204 { + t.Fatalf("expected UNLOCK to succeed with a 204, but got status %d", resp.StatusCode) + } + + // Then delete without the lock (should succeed) + req, err = http.NewRequest("DELETE", u, nil) + if err != nil { + log.Fatal(err) + } + resp, err = client.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 204 { + t.Fatalf("deleting without lock should have succeeded with 204, but got %d", resp.StatusCode) + } +} + type local struct { l net.Listener fs *FileSystemForLocal @@ -486,3 +692,9 @@ func (a *noopAuthenticator) Clone() gowebdav.Authenticator { func (a *noopAuthenticator) Close() error { return nil } + +const lockBody = ` + + + +` diff --git a/drive/driveimpl/fileserver.go b/drive/driveimpl/fileserver.go index e9ea7331e8f0e..0067c1cc7db63 100644 --- a/drive/driveimpl/fileserver.go +++ b/drive/driveimpl/fileserver.go @@ -151,6 +151,9 @@ func (s *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) return } + // WebDAV's locking code compares the lock resources with the request's + // host header, set this to empty to avoid mismatches. + r.Host = "" h.ServeHTTP(w, r) } diff --git a/drive/driveimpl/shared/xml.go b/drive/driveimpl/shared/xml.go new file mode 100644 index 0000000000000..79fd0885dd500 --- /dev/null +++ b/drive/driveimpl/shared/xml.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package shared + +import ( + "bytes" + "encoding/xml" +) + +// EscapeForXML escapes the given string for use in XML text. +func EscapeForXML(s string) string { + result := bytes.NewBuffer(nil) + xml.Escape(result, []byte(s)) + return result.String() +}