diff --git a/packet.go b/packet.go index 1a1a87d7..d89ad997 100644 --- a/packet.go +++ b/packet.go @@ -1247,7 +1247,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { } func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { - err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) + err := os.Rename(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath)) return statusFromError(p.ID, err) } @@ -1276,6 +1276,6 @@ func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error { } func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket { - err := os.Link(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) + err := os.Link(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath)) return statusFromError(p.ID, err) } diff --git a/request-plan9.go b/request-plan9.go index 2444da59..0e3b4836 100644 --- a/request-plan9.go +++ b/request-plan9.go @@ -3,8 +3,6 @@ package sftp import ( - "path" - "path/filepath" "syscall" ) @@ -15,20 +13,3 @@ func fakeFileInfoSys() interface{} { func testOsSys(sys interface{}) error { return nil } - -func toLocalPath(p string) string { - lp := filepath.FromSlash(p) - - if path.IsAbs(p) { - tmp := lp[1:] - - if filepath.IsAbs(tmp) { - // If the FromSlash without any starting slashes is absolute, - // then we have a filepath encoded with a prefix '/'. - // e.g. "/#s/boot" to "#s/boot" - return tmp - } - } - - return lp -} diff --git a/request-unix.go b/request-unix.go index 50b08a38..d30b2569 100644 --- a/request-unix.go +++ b/request-unix.go @@ -21,7 +21,3 @@ func testOsSys(sys interface{}) error { } return nil } - -func toLocalPath(p string) string { - return p -} diff --git a/request_windows.go b/request_windows.go index 1f6d3df1..bd1d6864 100644 --- a/request_windows.go +++ b/request_windows.go @@ -1,8 +1,6 @@ package sftp import ( - "path" - "path/filepath" "syscall" ) @@ -13,32 +11,3 @@ func fakeFileInfoSys() interface{} { func testOsSys(sys interface{}) error { return nil } - -func toLocalPath(p string) string { - lp := filepath.FromSlash(p) - - if path.IsAbs(p) { - tmp := lp - for len(tmp) > 0 && tmp[0] == '\\' { - tmp = tmp[1:] - } - - if filepath.IsAbs(tmp) { - // If the FromSlash without any starting slashes is absolute, - // then we have a filepath encoded with a prefix '/'. - // e.g. "/C:/Windows" to "C:\\Windows" - return tmp - } - - tmp += "\\" - - if filepath.IsAbs(tmp) { - // If the FromSlash without any starting slashes but with extra end slash is absolute, - // then we have a filepath encoded with a prefix '/' and a dropped '/' at the end. - // e.g. "/C:" to "C:\\" - return tmp - } - } - - return lp -} diff --git a/server.go b/server.go index 529052b4..c0665ed9 100644 --- a/server.go +++ b/server.go @@ -33,6 +33,7 @@ type Server struct { openFiles map[string]*os.File openFilesLock sync.RWMutex handleCount int + workDir string } func (svr *Server) nextHandle(f *os.File) string { @@ -128,6 +129,16 @@ func WithAllocator() ServerOption { } } +// WithServerWorkingDirectory sets a working directory to use as base +// for relative paths. +// If unset the default is current working directory (os.Getwd). +func WithServerWorkingDirectory(workDir string) ServerOption { + return func(s *Server) error { + s.workDir = cleanPath(workDir) + return nil + } +} + type rxPacket struct { pktType fxp pktBytes []byte @@ -174,7 +185,7 @@ func handlePacket(s *Server, p orderedRequest) error { } case *sshFxpStatPacket: // stat the requested file - info, err := os.Stat(toLocalPath(p.Path)) + info, err := os.Stat(s.toLocalPath(p.Path)) rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, @@ -184,7 +195,7 @@ func handlePacket(s *Server, p orderedRequest) error { } case *sshFxpLstatPacket: // stat the requested file - info, err := os.Lstat(toLocalPath(p.Path)) + info, err := os.Lstat(s.toLocalPath(p.Path)) rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, @@ -208,24 +219,24 @@ func handlePacket(s *Server, p orderedRequest) error { } case *sshFxpMkdirPacket: // TODO FIXME: ignore flags field - err := os.Mkdir(toLocalPath(p.Path), 0755) + err := os.Mkdir(s.toLocalPath(p.Path), 0o755) rpkt = statusFromError(p.ID, err) case *sshFxpRmdirPacket: - err := os.Remove(toLocalPath(p.Path)) + err := os.Remove(s.toLocalPath(p.Path)) rpkt = statusFromError(p.ID, err) case *sshFxpRemovePacket: - err := os.Remove(toLocalPath(p.Filename)) + err := os.Remove(s.toLocalPath(p.Filename)) rpkt = statusFromError(p.ID, err) case *sshFxpRenamePacket: - err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) + err := os.Rename(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath)) rpkt = statusFromError(p.ID, err) case *sshFxpSymlinkPacket: - err := os.Symlink(toLocalPath(p.Targetpath), toLocalPath(p.Linkpath)) + err := os.Symlink(s.toLocalPath(p.Targetpath), s.toLocalPath(p.Linkpath)) rpkt = statusFromError(p.ID, err) case *sshFxpClosePacket: rpkt = statusFromError(p.ID, s.closeHandle(p.Handle)) case *sshFxpReadlinkPacket: - f, err := os.Readlink(toLocalPath(p.Path)) + f, err := os.Readlink(s.toLocalPath(p.Path)) rpkt = &sshFxpNamePacket{ ID: p.ID, NameAttrs: []*sshFxpNameAttr{ @@ -240,7 +251,7 @@ func handlePacket(s *Server, p orderedRequest) error { rpkt = statusFromError(p.ID, err) } case *sshFxpRealpathPacket: - f, err := filepath.Abs(toLocalPath(p.Path)) + f, err := filepath.Abs(s.toLocalPath(p.Path)) f = cleanPath(f) rpkt = &sshFxpNamePacket{ ID: p.ID, @@ -256,13 +267,14 @@ func handlePacket(s *Server, p orderedRequest) error { rpkt = statusFromError(p.ID, err) } case *sshFxpOpendirPacket: - p.Path = toLocalPath(p.Path) + lp := s.toLocalPath(p.Path) - if stat, err := os.Stat(p.Path); err != nil { + if stat, err := os.Stat(lp); err != nil { rpkt = statusFromError(p.ID, err) } else if !stat.IsDir() { rpkt = statusFromError(p.ID, &os.PathError{ - Path: p.Path, Err: syscall.ENOTDIR}) + Path: lp, Err: syscall.ENOTDIR, + }) } else { rpkt = (&sshFxpOpenPacket{ ID: p.ID, @@ -446,7 +458,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { osFlags |= os.O_EXCL } - f, err := os.OpenFile(toLocalPath(p.Path), osFlags, 0644) + f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644) if err != nil { return statusFromError(p.ID, err) } @@ -484,7 +496,7 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { b := p.Attrs.([]byte) var err error - p.Path = toLocalPath(p.Path) + p.Path = svr.toLocalPath(p.Path) debug("setstat name \"%s\"", p.Path) if (p.Flags & sshFileXferAttrSize) != 0 { diff --git a/server_nowindows_test.go b/server_nowindows_test.go new file mode 100644 index 00000000..d38dbba3 --- /dev/null +++ b/server_nowindows_test.go @@ -0,0 +1,87 @@ +//go:build !windows +// +build !windows + +package sftp + +import ( + "testing" +) + +func TestServer_toLocalPath(t *testing.T) { + tests := []struct { + name string + withWorkDir string + p string + want string + }{ + { + name: "empty path with no workdir", + p: "", + want: "", + }, + { + name: "relative path with no workdir", + p: "file", + want: "file", + }, + { + name: "absolute path with no workdir", + p: "/file", + want: "/file", + }, + { + name: "workdir and empty path", + withWorkDir: "/home/user", + p: "", + want: "/home/user", + }, + { + name: "workdir and relative path", + withWorkDir: "/home/user", + p: "file", + want: "/home/user/file", + }, + { + name: "workdir and relative path with .", + withWorkDir: "/home/user", + p: ".", + want: "/home/user", + }, + { + name: "workdir and relative path with . and file", + withWorkDir: "/home/user", + p: "./file", + want: "/home/user/file", + }, + { + name: "workdir and absolute path", + withWorkDir: "/home/user", + p: "/file", + want: "/file", + }, + { + name: "workdir and non-unixy path prefixes workdir", + withWorkDir: "/home/user", + p: "C:\\file", + // This may look like a bug but it is the result of passing + // invalid input (a non-unixy path) to the server. + want: "/home/user/C:\\file", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We don't need to initialize the Server further to test + // toLocalPath behavior. + s := &Server{} + if tt.withWorkDir != "" { + if err := WithServerWorkingDirectory(tt.withWorkDir)(s); err != nil { + t.Fatal(err) + } + } + + if got := s.toLocalPath(tt.p); got != tt.want { + t.Errorf("Server.toLocalPath() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/server_plan9.go b/server_plan9.go new file mode 100644 index 00000000..4e8ed067 --- /dev/null +++ b/server_plan9.go @@ -0,0 +1,27 @@ +package sftp + +import ( + "path" + "path/filepath" +) + +func (s *Server) toLocalPath(p string) string { + if s.workDir != "" && !path.IsAbs(p) { + p = path.Join(s.workDir, p) + } + + lp := filepath.FromSlash(p) + + if path.IsAbs(p) { + tmp := lp[1:] + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes is absolute, + // then we have a filepath encoded with a prefix '/'. + // e.g. "/#s/boot" to "#s/boot" + return tmp + } + } + + return lp +} diff --git a/server_unix.go b/server_unix.go new file mode 100644 index 00000000..495b397c --- /dev/null +++ b/server_unix.go @@ -0,0 +1,16 @@ +//go:build !windows && !plan9 +// +build !windows,!plan9 + +package sftp + +import ( + "path" +) + +func (s *Server) toLocalPath(p string) string { + if s.workDir != "" && !path.IsAbs(p) { + p = path.Join(s.workDir, p) + } + + return p +} diff --git a/server_windows.go b/server_windows.go new file mode 100644 index 00000000..b35be730 --- /dev/null +++ b/server_windows.go @@ -0,0 +1,39 @@ +package sftp + +import ( + "path" + "path/filepath" +) + +func (s *Server) toLocalPath(p string) string { + if s.workDir != "" && !path.IsAbs(p) { + p = path.Join(s.workDir, p) + } + + lp := filepath.FromSlash(p) + + if path.IsAbs(p) { + tmp := lp + for len(tmp) > 0 && tmp[0] == '\\' { + tmp = tmp[1:] + } + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes is absolute, + // then we have a filepath encoded with a prefix '/'. + // e.g. "/C:/Windows" to "C:\\Windows" + return tmp + } + + tmp += "\\" + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes but with extra end slash is absolute, + // then we have a filepath encoded with a prefix '/' and a dropped '/' at the end. + // e.g. "/C:" to "C:\\" + return tmp + } + } + + return lp +} diff --git a/server_windows_test.go b/server_windows_test.go new file mode 100644 index 00000000..ca9ed027 --- /dev/null +++ b/server_windows_test.go @@ -0,0 +1,84 @@ +package sftp + +import ( + "testing" +) + +func TestServer_toLocalPath(t *testing.T) { + tests := []struct { + name string + withWorkDir string + p string + want string + }{ + { + name: "empty path with no workdir", + p: "", + want: "", + }, + { + name: "relative path with no workdir", + p: "file", + want: "file", + }, + { + name: "absolute path with no workdir", + p: "/file", + want: "\\file", + }, + { + name: "workdir and empty path", + withWorkDir: "C:\\Users\\User", + p: "", + want: "C:\\Users\\User", + }, + { + name: "workdir and relative path", + withWorkDir: "C:\\Users\\User", + p: "file", + want: "C:\\Users\\User\\file", + }, + { + name: "workdir and relative path with .", + withWorkDir: "C:\\Users\\User", + p: ".", + want: "C:\\Users\\User", + }, + { + name: "workdir and relative path with . and file", + withWorkDir: "C:\\Users\\User", + p: "./file", + want: "C:\\Users\\User\\file", + }, + { + name: "workdir and absolute path", + withWorkDir: "C:\\Users\\User", + p: "/C:/file", + want: "C:\\file", + }, + { + name: "workdir and non-unixy path prefixes workdir", + withWorkDir: "C:\\Users\\User", + p: "C:\\file", + // This may look like a bug but it is the result of passing + // invalid input (a non-unixy path) to the server. + want: "C:\\Users\\User\\C:\\file", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We don't need to initialize the Server further to test + // toLocalPath behavior. + s := &Server{} + if tt.withWorkDir != "" { + if err := WithServerWorkingDirectory(tt.withWorkDir)(s); err != nil { + t.Fatal(err) + } + } + + if got := s.toLocalPath(tt.p); got != tt.want { + t.Errorf("Server.toLocalPath() = %q, want %q", got, tt.want) + } + }) + } +}