diff --git a/packet.go b/packet.go index cbaa90e7..bfe6a3c9 100644 --- a/packet.go +++ b/packet.go @@ -823,7 +823,7 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { // So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length) const dataHeaderLen = 4 + 1 + 4 + 4 -func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { +func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte { dataLen := p.Len if dataLen > maxTxPacket { dataLen = maxTxPacket diff --git a/request-server.go b/request-server.go index 7a99db64..11047e6b 100644 --- a/request-server.go +++ b/request-server.go @@ -10,7 +10,7 @@ import ( "sync" ) -var maxTxPacket uint32 = 1 << 15 +const defaultMaxTxPacket uint32 = 1 << 15 // Handlers contains the 4 SFTP server request handlers. type Handlers struct { @@ -28,6 +28,7 @@ type RequestServer struct { pktMgr *packetManager startDirectory string + maxTxPacket uint32 mu sync.RWMutex handleCount int @@ -57,6 +58,22 @@ func WithStartDirectory(startDirectory string) RequestServerOption { } } +// WithRSMaxTxPacket sets the maximum size of the payload returned to the client, +// measured in bytes. The default value is 32768 bytes, and this option +// can only be used to increase it. Setting this option to a larger value +// should be safe, because the client decides the size of the requested payload. +// +// The default maximum packet size is 32768 bytes. +func WithRSMaxTxPacket(size uint32) RequestServerOption { + return func(rs *RequestServer) { + if size < defaultMaxTxPacket { + return + } + + rs.maxTxPacket = size + } +} + // NewRequestServer creates/allocates/returns new RequestServer. // Normally there will be one server per user-session. func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer { @@ -73,6 +90,7 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ pktMgr: newPktMgr(svrConn), startDirectory: "/", + maxTxPacket: defaultMaxTxPacket, openRequests: make(map[string]*Request), } @@ -260,7 +278,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR Method: "Stat", Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath), } - rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) } case *sshFxpFsetstatPacket: handle := pkt.getHandle() @@ -272,7 +290,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR Method: "Setstat", Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath), } - rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) } case *sshFxpExtendedPacketPosixRename: request := &Request{ @@ -280,24 +298,24 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath), Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath), } - rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) case *sshFxpExtendedPacketStatVFS: request := &Request{ Method: "StatVFS", Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path), } - rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) case hasHandle: handle := pkt.getHandle() request, ok := rs.getRequest(handle) if !ok { rpkt = statusFromError(pkt.id(), EBADF) } else { - rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) } case hasPath: request := requestFromPacket(ctx, pkt, rs.startDirectory) - rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) request.close() default: rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported) diff --git a/request.go b/request.go index cd652cfd..e7c47a9c 100644 --- a/request.go +++ b/request.go @@ -300,14 +300,14 @@ func (r *Request) transferError(err error) { } // called from worker to handle packet/request -func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { +func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { switch r.Method { case "Get": - return fileget(handlers.FileGet, r, pkt, alloc, orderID) + return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket) case "Put": - return fileput(handlers.FilePut, r, pkt, alloc, orderID) + return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket) case "Open": - return fileputget(handlers.FilePut, r, pkt, alloc, orderID) + return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket) case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS": return filecmd(handlers.FileCmd, r, pkt) case "List": @@ -392,13 +392,13 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { } // wrap FileReader handler -func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { +func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { rd := r.getReaderAt() if rd == nil { return statusFromError(pkt.id(), errors.New("unexpected read packet")) } - data, offset, _ := packetData(pkt, alloc, orderID) + data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket) n, err := rd.ReadAt(data, offset) // only return EOF error if no data left to read @@ -414,20 +414,20 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde } // wrap FileWriter handler -func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { +func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { wr := r.getWriterAt() if wr == nil { return statusFromError(pkt.id(), errors.New("unexpected write packet")) } - data, offset, _ := packetData(pkt, alloc, orderID) + data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket) _, err := wr.WriteAt(data, offset) return statusFromError(pkt.id(), err) } // wrap OpenFileWriter handler -func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { +func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { rw := r.getWriterAtReaderAt() if rw == nil { return statusFromError(pkt.id(), errors.New("unexpected write and read packet")) @@ -435,7 +435,7 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o switch p := pkt.(type) { case *sshFxpReadPacket: - data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset) + data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset) n, err := rw.ReadAt(data, offset) // only return EOF error if no data left to read @@ -461,10 +461,10 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o } // file data for additional read/write packets -func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) { +func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) { switch p := p.(type) { case *sshFxpReadPacket: - return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len + return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len case *sshFxpWritePacket: return p.Data, int64(p.Offset), p.Length } diff --git a/request_test.go b/request_test.go index 92f7c2bf..807833aa 100644 --- a/request_test.go +++ b/request_test.go @@ -149,7 +149,7 @@ func TestRequestGet(t *testing.T) { for i, txt := range []string{"file-", "data."} { pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a", Offset: uint64(i * 5), Len: 5} - rpkt := request.call(handlers, pkt, nil, 0) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) dpkt := rpkt.(*sshFxpDataPacket) assert.Equal(t, dpkt.id(), uint32(i)) assert.Equal(t, string(dpkt.Data), txt) @@ -162,7 +162,7 @@ func TestRequestCustomError(t *testing.T) { pkt := fakePacket{myid: 1} cmdErr := errors.New("stat not supported") handlers.returnError(cmdErr) - rpkt := request.call(handlers, pkt, nil, 0) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr)) } @@ -173,11 +173,11 @@ func TestRequestPut(t *testing.T) { request.state.writerAt, _ = handlers.FilePut.Filewrite(request) pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5, Data: []byte("file-")} - rpkt := request.call(handlers, pkt, nil, 0) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) checkOkStatus(t, rpkt) pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5, Data: []byte("data.")} - rpkt = request.call(handlers, pkt, nil, 0) + rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) checkOkStatus(t, rpkt) assert.Equal(t, "file-data.", handlers.getOutString()) } @@ -186,11 +186,11 @@ func TestRequestCmdr(t *testing.T) { handlers := newTestHandlers() request := testRequest("Mkdir") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt, nil, 0) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) checkOkStatus(t, rpkt) handlers.returnError(errTest) - rpkt = request.call(handlers, pkt, nil, 0) + rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest)) } @@ -198,7 +198,7 @@ func TestRequestInfoStat(t *testing.T) { handlers := newTestHandlers() request := testRequest("Stat") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt, nil, 0) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) spkt, ok := rpkt.(*sshFxpStatResponse) assert.True(t, ok) assert.Equal(t, spkt.info.Name(), "request_test.go") @@ -215,13 +215,13 @@ func TestRequestInfoList(t *testing.T) { assert.Equal(t, hpkt.Handle, "1") } pkt = fakePacket{myid: 2} - request.call(handlers, pkt, nil, 0) + request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) } func TestRequestInfoReadlink(t *testing.T) { handlers := newTestHandlers() request := testRequest("Readlink") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt, nil, 0) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) npkt, ok := rpkt.(*sshFxpNamePacket) if assert.True(t, ok) { assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0]) @@ -234,7 +234,7 @@ func TestOpendirHandleReuse(t *testing.T) { request := testRequest("Stat") request.handle = "1" pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt, nil, 0) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) assert.IsType(t, &sshFxpStatResponse{}, rpkt) request.Method = "List" @@ -244,6 +244,6 @@ func TestOpendirHandleReuse(t *testing.T) { hpkt := rpkt.(*sshFxpHandlePacket) assert.Equal(t, hpkt.Handle, "1") } - rpkt = request.call(handlers, pkt, nil, 0) + rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) assert.IsType(t, &sshFxpNamePacket{}, rpkt) } diff --git a/server.go b/server.go index acdc30ed..fb474c4f 100644 --- a/server.go +++ b/server.go @@ -34,6 +34,7 @@ type Server struct { openFilesLock sync.RWMutex handleCount int workDir string + maxTxPacket uint32 } func (svr *Server) nextHandle(f *os.File) string { @@ -86,6 +87,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) debugStream: ioutil.Discard, pktMgr: newPktMgr(svrConn), openFiles: make(map[string]*os.File), + maxTxPacket: defaultMaxTxPacket, } for _, o := range options { @@ -139,6 +141,24 @@ func WithServerWorkingDirectory(workDir string) ServerOption { } } +// WithMaxTxPacket sets the maximum size of the payload returned to the client, +// measured in bytes. The default value is 32768 bytes, and this option +// can only be used to increase it. Setting this option to a larger value +// should be safe, because the client decides the size of the requested payload. +// +// The default maximum packet size is 32768 bytes. +func WithMaxTxPacket(size uint32) ServerOption { + return func(s *Server) error { + if size < defaultMaxTxPacket { + return errors.New("size must be greater than or equal to 32768") + } + + s.maxTxPacket = size + + return nil + } +} + type rxPacket struct { pktType fxp pktBytes []byte @@ -287,7 +307,7 @@ func handlePacket(s *Server, p orderedRequest) error { f, ok := s.getHandle(p.Handle) if ok { err = nil - data := p.getDataSlice(s.pktMgr.alloc, orderID) + data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket) n, _err := f.ReadAt(data, int64(p.Offset)) if _err != nil && (_err != io.EOF || n == 0) { err = _err @@ -513,16 +533,16 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { fs, err := p.unmarshalFileStat(p.Flags) - if err == nil && (p.Flags & sshFileXferAttrSize) != 0 { + if err == nil && (p.Flags&sshFileXferAttrSize) != 0 { err = os.Truncate(path, int64(fs.Size)) } - if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 { + if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 { err = os.Chmod(path, fs.FileMode()) } - if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 { + if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 { err = os.Chown(path, int(fs.UID), int(fs.GID)) } - if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 { + if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 { err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } @@ -541,16 +561,16 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { fs, err := p.unmarshalFileStat(p.Flags) - if err == nil && (p.Flags & sshFileXferAttrSize) != 0 { + if err == nil && (p.Flags&sshFileXferAttrSize) != 0 { err = f.Truncate(int64(fs.Size)) } - if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 { + if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 { err = f.Chmod(fs.FileMode()) } - if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 { + if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 { err = f.Chown(int(fs.UID), int(fs.GID)) } - if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 { + if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 { type chtimer interface { Chtimes(atime, mtime time.Time) error }