diff --git a/drive/driveimpl/compositedav/compositedav.go b/drive/driveimpl/compositedav/compositedav.go
index 8b41871ad49ce7..e5c16f1a107c25 100644
--- a/drive/driveimpl/compositedav/compositedav.go
+++ b/drive/driveimpl/compositedav/compositedav.go
@@ -93,8 +93,17 @@ var cacheInvalidatingMethods = map[string]bool{
// ServeHTTP implements http.Handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- if r.Method == "PROPFIND" {
- h.handlePROPFIND(w, r)
+ pathComponents := shared.CleanAndSplit(r.URL.Path)
+ mpl := h.maxPathLength(r)
+
+ rewriteIfHeader(r, pathComponents, mpl)
+
+ switch r.Method {
+ case "PROPFIND":
+ h.handlePROPFIND(w, r, pathComponents, mpl)
+ return
+ case "LOCK":
+ h.handleLOCK(w, r, pathComponents, mpl)
return
}
@@ -107,9 +116,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.StatCache.invalidate()
}
- mpl := h.maxPathLength(r)
- pathComponents := shared.CleanAndSplit(r.URL.Path)
-
if len(pathComponents) >= mpl {
h.delegate(mpl, pathComponents[mpl-1:], w, r)
return
diff --git a/drive/driveimpl/compositedav/propfind.go b/drive/driveimpl/compositedav/propfind.go
deleted file mode 100644
index 5e6ccfa0bb8e10..00000000000000
--- a/drive/driveimpl/compositedav/propfind.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package compositedav
-
-import (
- "bytes"
- "fmt"
- "math"
- "net/http"
- "regexp"
-
- "tailscale.com/drive/driveimpl/shared"
-)
-
-var (
- hrefRegex = regexp.MustCompile(`(?s)/?([^<]*)/?`)
-)
-
-func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request) {
- pathComponents := shared.CleanAndSplit(r.URL.Path)
- mpl := h.maxPathLength(r)
- if !shared.IsRoot(r.URL.Path) && len(pathComponents)+getDepth(r) > mpl {
- // Delegate to a Child.
- depth := getDepth(r)
-
- status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) {
- // Use a buffering ResponseWriter so that we can manipulate the result.
- // The only thing we use from the original ResponseWriter is Header().
- bw := &bufferingResponseWriter{ResponseWriter: w}
-
- mpl := h.maxPathLength(r)
- h.delegate(mpl, pathComponents[mpl-1:], bw, r)
-
- // Fixup paths to add the requested path as a prefix.
- pathPrefix := shared.Join(pathComponents[0:mpl]...)
- b := hrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("%s/$1", pathPrefix)))
-
- return bw.status, b
- })
-
- w.Header().Del("Content-Length")
- w.WriteHeader(status)
- if result != nil {
- w.Write(result)
- }
- return
- }
-
- h.handle(w, r)
-}
-
-func getDepth(r *http.Request) int {
- switch r.Header.Get("Depth") {
- case "0":
- return 0
- case "1":
- return 1
- case "infinity":
- return math.MaxInt
- }
- return 0
-}
-
-type bufferingResponseWriter struct {
- http.ResponseWriter
- status int
- buf bytes.Buffer
-}
-
-func (bw *bufferingResponseWriter) WriteHeader(statusCode int) {
- bw.status = statusCode
-}
-
-func (bw *bufferingResponseWriter) Write(p []byte) (int, error) {
- return bw.buf.Write(p)
-}
diff --git a/drive/driveimpl/compositedav/rewriting.go b/drive/driveimpl/compositedav/rewriting.go
new file mode 100644
index 00000000000000..3dcc9be0a7f9ad
--- /dev/null
+++ b/drive/driveimpl/compositedav/rewriting.go
@@ -0,0 +1,112 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package compositedav
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "net/http"
+ "regexp"
+ "strings"
+
+ "tailscale.com/drive/driveimpl/shared"
+)
+
+var (
+ responseHrefRegex = regexp.MustCompile(`(?s)()/?([^<]*)/?`)
+ ifHrefRegex = regexp.MustCompile(`^<(https?://[^/]+)?([^>]+)>`)
+)
+
+func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) {
+ if shouldDelegateToChild(r, pathComponents, mpl) {
+ // Delegate to a Child.
+ depth := getDepth(r)
+
+ status, result := h.StatCache.getOr(r.URL.Path, depth, func() (int, []byte) {
+ return h.delegateRewriting(w, r, pathComponents, mpl)
+ })
+
+ respondRewritten(w, status, result)
+ return
+ }
+
+ h.handle(w, r)
+}
+
+func (h *Handler) handleLOCK(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) {
+ if shouldDelegateToChild(r, pathComponents, mpl) {
+ // Delegate to a Child.
+ status, result := h.delegateRewriting(w, r, pathComponents, mpl)
+ respondRewritten(w, status, result)
+ return
+ }
+
+ http.Error(w, "locking of top level directories is not allowed", http.StatusMethodNotAllowed)
+}
+
+func shouldDelegateToChild(r *http.Request, pathComponents []string, mpl int) bool {
+ return !shared.IsRoot(r.URL.Path) && len(pathComponents)+getDepth(r) > mpl
+}
+
+func (h *Handler) delegateRewriting(w http.ResponseWriter, r *http.Request, pathComponents []string, mpl int) (int, []byte) {
+ // Use a buffering ResponseWriter so that we can manipulate the result.
+ // The only thing we use from the original ResponseWriter is Header().
+ bw := &bufferingResponseWriter{ResponseWriter: w}
+
+ h.delegate(mpl, pathComponents[mpl-1:], bw, r)
+
+ // Fixup paths to add the requested path as a prefix, escaped for inclusion in XML.
+ pp := shared.EscapeForXML(shared.Join(pathComponents[0:mpl]...))
+ b := responseHrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("$1%s/$3", pp)))
+ return bw.status, b
+}
+
+func respondRewritten(w http.ResponseWriter, status int, result []byte) {
+ w.Header().Del("Content-Length")
+ w.WriteHeader(status)
+ if result != nil {
+ w.Write(result)
+ }
+}
+
+func getDepth(r *http.Request) int {
+ switch r.Header.Get("Depth") {
+ case "0":
+ return 0
+ case "1":
+ return 1
+ case "infinity":
+ return math.MaxInt16 // a really large number, but not infinity (avoids wrapping when we do arithmetic with this)
+ }
+ return 0
+}
+
+type bufferingResponseWriter struct {
+ http.ResponseWriter
+ status int
+ buf bytes.Buffer
+}
+
+func (bw *bufferingResponseWriter) WriteHeader(statusCode int) {
+ bw.status = statusCode
+}
+
+func (bw *bufferingResponseWriter) Write(p []byte) (int, error) {
+ return bw.buf.Write(p)
+}
+
+func rewriteIfHeader(r *http.Request, pathComponents []string, mpl int) {
+ ih := r.Header.Get("If")
+ if ih == "" {
+ return
+ }
+ matches := ifHrefRegex.FindStringSubmatch(ih)
+ if len(matches) == 3 {
+ pp := shared.JoinEscaped(pathComponents[0:mpl]...)
+ p := strings.Replace(shared.JoinEscaped(pathComponents...), pp, "", 1)
+ nih := ifHrefRegex.ReplaceAllString(ih, fmt.Sprintf("<%s>", p))
+ r.Header.Set("If", nih)
+ }
+}
diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go
index 8e9d1a557a1c53..03570296ff62cc 100644
--- a/drive/driveimpl/drive_test.go
+++ b/drive/driveimpl/drive_test.go
@@ -14,6 +14,7 @@ import (
"os"
"path"
"path/filepath"
+ "regexp"
"slices"
"strings"
"sync"
@@ -30,14 +31,19 @@ import (
const (
domain = `test$%domain.com`
- remote1 = `rem ote$%1`
- remote2 = `_rem ote$%2`
- share11 = `sha re$%11`
- share12 = `_sha re$%12`
- file111 = `fi le$%111.txt`
+ remote1 = `rem ote$%<>1`
+ remote2 = `_rem ote$%<>2`
+ share11 = `sha re$%<>11`
+ share12 = `_sha re$%<>12`
+ file111 = `fi le$%<>111.txt`
file112 = `file112.txt`
)
+var (
+ lockrootRegex = regexp.MustCompile(`/?([^<]*)/?`)
+ lockTokenRegex = regexp.MustCompile(`([0-9]+)/?`)
+)
+
func init() {
// set AllowShareAs() to false so that we don't try to use sub-processes
// for access files on disk.
@@ -145,6 +151,206 @@ func TestSecretTokenAuth(t *testing.T) {
}
}
+func TestLOCK(t *testing.T) {
+ s := newSystem(t)
+
+ s.addRemote(remote1)
+ s.addShare(remote1, share11, drive.PermissionReadWrite)
+ s.writeFile("writing file to read/write remote should succeed", remote1, share11, file111, "hello world", true)
+
+ client := &http.Client{
+ Transport: &http.Transport{DisableKeepAlives: true},
+ }
+
+ u := fmt.Sprintf("http://%s/%s/%s/%s/%s",
+ s.local.l.Addr(),
+ url.PathEscape(domain),
+ url.PathEscape(remote1),
+ url.PathEscape(share11),
+ url.PathEscape(file111))
+
+ // First acquire a lock with a short timeout
+ req, err := http.NewRequest("LOCK", u, strings.NewReader(lockBody))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Depth", "infinity")
+ req.Header.Set("Timeout", "Second-1")
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("expected LOCK to succeed, but got status %d", resp.StatusCode)
+ }
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ submatches := lockrootRegex.FindStringSubmatch(string(body))
+ if len(submatches) != 2 {
+ t.Fatal("failed to find lockroot")
+ }
+ want := shared.EscapeForXML(pathTo(remote1, share11, file111))
+ got := submatches[1]
+ if got != want {
+ t.Fatalf("want lockroot %q, got %q", want, got)
+ }
+
+ submatches = lockTokenRegex.FindStringSubmatch(string(body))
+ if len(submatches) != 2 {
+ t.Fatal("failed to find locktoken")
+ }
+ lockToken := submatches[1]
+ ifHeader := fmt.Sprintf("<%s> (<%s>)", u, lockToken)
+
+ // Then refresh the lock with a longer timeout
+ req, err = http.NewRequest("LOCK", u, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Depth", "infinity")
+ req.Header.Set("Timeout", "Second-600")
+ req.Header.Set("If", ifHeader)
+ resp, err = client.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("expected LOCK refresh to succeed, but got status %d", resp.StatusCode)
+ }
+ body, err = io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ submatches = lockrootRegex.FindStringSubmatch(string(body))
+ if len(submatches) != 2 {
+ t.Fatal("failed to find lockroot after refresh")
+ }
+ want = shared.EscapeForXML(pathTo(remote1, share11, file111))
+ got = submatches[1]
+ if got != want {
+ t.Fatalf("want lockroot after refresh %q, got %q", want, got)
+ }
+
+ submatches = lockTokenRegex.FindStringSubmatch(string(body))
+ if len(submatches) != 2 {
+ t.Fatal("failed to find locktoken after refresh")
+ }
+ if submatches[1] != lockToken {
+ t.Fatalf("on refresh, lock token changed from %q to %q", lockToken, submatches[1])
+ }
+
+ // Then wait past the original timeout, then try to delete without the lock
+ // (should fail)
+ time.Sleep(1 * time.Second)
+ req, err = http.NewRequest("DELETE", u, nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+ resp, err = client.Do(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 423 {
+ t.Fatalf("deleting without lock token should fail with 423, but got %d", resp.StatusCode)
+ }
+
+ // Then delete with the lock (should succeed)
+ req, err = http.NewRequest("DELETE", u, nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+ req.Header.Set("If", ifHeader)
+ resp, err = client.Do(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 204 {
+ t.Fatalf("deleting with lock token should have succeeded with 204, but got %d", resp.StatusCode)
+ }
+}
+
+func TestUNLOCK(t *testing.T) {
+ s := newSystem(t)
+
+ s.addRemote(remote1)
+ s.addShare(remote1, share11, drive.PermissionReadWrite)
+ s.writeFile("writing file to read/write remote should succeed", remote1, share11, file111, "hello world", true)
+
+ client := &http.Client{
+ Transport: &http.Transport{DisableKeepAlives: true},
+ }
+
+ u := fmt.Sprintf("http://%s/%s/%s/%s/%s",
+ s.local.l.Addr(),
+ url.PathEscape(domain),
+ url.PathEscape(remote1),
+ url.PathEscape(share11),
+ url.PathEscape(file111))
+
+ // Acquire a lock
+ req, err := http.NewRequest("LOCK", u, strings.NewReader(lockBody))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Depth", "infinity")
+ req.Header.Set("Timeout", "Second-600")
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("expected LOCK to succeed, but got status %d", resp.StatusCode)
+ }
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ submatches := lockTokenRegex.FindStringSubmatch(string(body))
+ if len(submatches) != 2 {
+ t.Fatal("failed to find locktoken")
+ }
+ lockToken := submatches[1]
+
+ // Release the lock
+ req, err = http.NewRequest("UNLOCK", u, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Lock-Token", fmt.Sprintf("<%s>", lockToken))
+ resp, err = client.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 204 {
+ t.Fatalf("expected UNLOCK to succeed with a 204, but got status %d", resp.StatusCode)
+ }
+
+ // Then delete without the lock (should succeed)
+ req, err = http.NewRequest("DELETE", u, nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+ resp, err = client.Do(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 204 {
+ t.Fatalf("deleting without lock should have succeeded with 204, but got %d", resp.StatusCode)
+ }
+}
+
type local struct {
l net.Listener
fs *FileSystemForLocal
@@ -486,3 +692,9 @@ func (a *noopAuthenticator) Clone() gowebdav.Authenticator {
func (a *noopAuthenticator) Close() error {
return nil
}
+
+const lockBody = `
+
+
+
+`
diff --git a/drive/driveimpl/fileserver.go b/drive/driveimpl/fileserver.go
index e9ea7331e8f0e0..0067c1cc7db63a 100644
--- a/drive/driveimpl/fileserver.go
+++ b/drive/driveimpl/fileserver.go
@@ -151,6 +151,9 @@ func (s *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
return
}
+ // WebDAV's locking code compares the lock resources with the request's
+ // host header, set this to empty to avoid mismatches.
+ r.Host = ""
h.ServeHTTP(w, r)
}
diff --git a/drive/driveimpl/shared/xml.go b/drive/driveimpl/shared/xml.go
new file mode 100644
index 00000000000000..79fd0885dd5004
--- /dev/null
+++ b/drive/driveimpl/shared/xml.go
@@ -0,0 +1,16 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package shared
+
+import (
+ "bytes"
+ "encoding/xml"
+)
+
+// EscapeForXML escapes the given string for use in XML text.
+func EscapeForXML(s string) string {
+ result := bytes.NewBuffer(nil)
+ xml.Escape(result, []byte(s))
+ return result.String()
+}