diff --git a/attrs.go b/attrs.go index 758cd4ff..6027c6af 100644 --- a/attrs.go +++ b/attrs.go @@ -56,6 +56,50 @@ type FileStat struct { Extended []StatExtended } +func (fs FileStat) MarshalTo(b []byte, flags FileAttrFlags) []byte { + // attributes variable struct, and also variable per protocol version + // spec version 3 attributes: + // uint32 flags + // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE + // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS + // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED + // string extended_type + // string extended_data + // ... more extended data (extended_type - extended_data pairs), + // so that number of pairs equals extended_count + if flags.Size { + b = marshalUint64(b, fs.Size) + } + if flags.UidGid { + b = marshalUint32(b, fs.UID) + b = marshalUint32(b, fs.GID) + } + if flags.Permissions { + b = marshalUint32(b, fs.Mode) + } + if flags.Acmodtime { + b = marshalUint32(b, fs.Atime) + b = marshalUint32(b, fs.Mtime) + } + + // NOTE: This is subtle, this logic must not be changed without also changing the login in fileStatFromInfo. + // The rules on how sshFileXferAttrExtended gets set must match the rules on how we generate the packet. + if len(fs.Extended) > 0 { + b = marshalUint32(b, uint32(len(fs.Extended))) + + for _, attr := range fs.Extended { + b = marshalString(b, attr.ExtType) + b = marshalString(b, attr.ExtData) + } + } + + return b +} + // StatExtended contains additional, extended information for a FileStat. type StatExtended struct { ExtType string @@ -109,6 +153,9 @@ func fileStatFromInfo(fi os.FileInfo) (uint32, *FileStat) { fileStat.GID = fiExt.Gid() } + // NOTE: This is subtle, this logic must not be changed without also changing the login in marshalTo. + // The rules on how sshFileXferAttrExtended gets set must match the rules on how we generate the packet. + // // if fi implements FileInfoExtendedData, retrieve extended data from it if fiExt, ok := fi.(FileInfoExtendedData); ok { fileStat.Extended = fiExt.Extended() diff --git a/packet.go b/packet.go index 1232ff1e..076f73d4 100644 --- a/packet.go +++ b/packet.go @@ -54,31 +54,10 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte { // so that number of pairs equals extended_count flags, fileStat := fileStatFromInfo(fi) + f := newFileAttrFlags(flags) b = marshalUint32(b, flags) - if flags&sshFileXferAttrSize != 0 { - b = marshalUint64(b, fileStat.Size) - } - if flags&sshFileXferAttrUIDGID != 0 { - b = marshalUint32(b, fileStat.UID) - b = marshalUint32(b, fileStat.GID) - } - if flags&sshFileXferAttrPermissions != 0 { - b = marshalUint32(b, fileStat.Mode) - } - if flags&sshFileXferAttrACmodTime != 0 { - b = marshalUint32(b, fileStat.Atime) - b = marshalUint32(b, fileStat.Mtime) - } - - if flags&sshFileXferAttrExtended != 0 { - b = marshalUint32(b, uint32(len(fileStat.Extended))) - - for _, attr := range fileStat.Extended { - b = marshalString(b, attr.ExtType) - b = marshalString(b, attr.ExtData) - } - } + b = fileStat.MarshalTo(b, f) return b } diff --git a/request-attrs.go b/request-attrs.go index b5c95b4a..4d4fd63b 100644 --- a/request-attrs.go +++ b/request-attrs.go @@ -22,6 +22,29 @@ func newFileOpenFlags(flags uint32) FileOpenFlags { } } +func (fof FileOpenFlags) ForRequest() (flags uint32) { + if fof.Read { + flags |= sshFxfRead + } + if fof.Write { + flags |= sshFxfWrite + } + if fof.Append { + flags |= sshFxfAppend + } + if fof.Creat { + flags |= sshFxfCreat + } + if fof.Trunc { + flags |= sshFxfTrunc + } + if fof.Excl { + flags |= sshFxfExcl + } + + return flags +} + // Pflags converts the bitmap/uint32 from SFTP Open packet pflag values, // into a FileOpenFlags struct with booleans set for flags set in bitmap. func (r *Request) Pflags() FileOpenFlags { @@ -35,6 +58,23 @@ type FileAttrFlags struct { Size, UidGid, Permissions, Acmodtime bool } +func (faf FileAttrFlags) ForRequest() (flags uint32) { + if faf.Size { + flags |= sshFileXferAttrSize + } + if faf.UidGid { + flags |= sshFileXferAttrUIDGID + } + if faf.Permissions { + flags |= sshFileXferAttrPermissions + } + if faf.Acmodtime { + flags |= sshFileXferAttrACmodTime + } + + return flags +} + func newFileAttrFlags(flags uint32) FileAttrFlags { return FileAttrFlags{ Size: (flags & sshFileXferAttrSize) != 0, diff --git a/request-example.go b/request-example.go index 519b3b76..89df58b3 100644 --- a/request-example.go +++ b/request-example.go @@ -37,7 +37,8 @@ func (fs *root) Fileread(r *Request) (io.ReaderAt, error) { return nil, os.ErrInvalid } - return fs.OpenFile(r) + // Needs to be readable by the owner. + return fs.openFileModeCheck(r, 0o400) } func (fs *root) Filewrite(r *Request) (io.WriterAt, error) { @@ -47,10 +48,16 @@ func (fs *root) Filewrite(r *Request) (io.WriterAt, error) { return nil, os.ErrInvalid } - return fs.OpenFile(r) + // Needs to be writable by the owner. + return fs.openFileModeCheck(r, 0o200) } func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) { + // Needs to be readable and writable by the owner. + return fs.openFileModeCheck(r, 0o200|0o400) +} + +func (fs *root) openFileModeCheck(r *Request, mode uint32) (WriterAtReaderAt, error) { if fs.mockErr != nil { return nil, fs.mockErr } @@ -59,7 +66,16 @@ func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) { fs.mu.Lock() defer fs.mu.Unlock() - return fs.openfile(r.Filepath, r.Flags) + f, err := fs.openfile(r.Filepath, r.Flags) + if err != nil { + return nil, err + } + + if f.mode&mode != mode { + return nil, os.ErrPermission + } + + return f, nil } func (fs *root) putfile(pathname string, file *memFile) error { @@ -72,7 +88,7 @@ func (fs *root) putfile(pathname string, file *memFile) error { return os.ErrInvalid } - if _, err := fs.lfetch(pathname); err != os.ErrNotExist { + if _, err := fs.lfetch(pathname); !errors.Is(err, os.ErrNotExist) { return os.ErrExist } @@ -108,8 +124,10 @@ func (fs *root) openfile(pathname string, flags uint32) (*memFile, error) { link, err = fs.lfetch(pathname) } + // The mode is currently hard coded because this library doesn't parse out the mode at file open time. file := &memFile{ modtime: time.Now(), + mode: 0644, } if err := fs.putfile(pathname, file); err != nil { @@ -151,15 +169,30 @@ func (fs *root) Filecmd(r *Request) error { switch r.Method { case "Setstat": + // Some notes: + // + // openfile will follow symlinks, however as best as I can tell this is the correct POSIX behavior for chmod. + // + // openfile does not currently support opening a directory, and at this time we do not implement directory permissions. + flags := r.AttrFlags() + attrs := r.Attributes() file, err := fs.openfile(r.Filepath, sshFxfWrite) if err != nil { return err } - if r.AttrFlags().Size { - return file.Truncate(int64(r.Attributes().Size)) + if flags.Size { + if err := file.Truncate(int64(attrs.Size)); err != nil { + return err + } + } + if flags.Permissions { + file.chmod(attrs.Mode) + } + // We only have mtime, not atime. + if flags.Acmodtime { + file.modtime = time.Unix(int64(attrs.Mtime), 0) } - return nil case "Rename": @@ -209,7 +242,7 @@ func (fs *root) rename(oldpath, newpath string) error { } target, err := fs.lfetch(newpath) - if err != os.ErrNotExist { + if !errors.Is(err, os.ErrNotExist) { if target == file { // IEEE 1003.1: if oldpath and newpath are the same directory entry, // then return no error, and perform no further action. @@ -507,7 +540,7 @@ func (fs *root) exists(path string) bool { _, err = fs.lfetch(path) - return err != os.ErrNotExist + return !errors.Is(err, os.ErrNotExist) } func (fs *root) fetch(pathname string) (*memFile, error) { @@ -544,6 +577,7 @@ type memFile struct { modtime time.Time symlink string isdir bool + mode uint32 mu sync.RWMutex content []byte @@ -563,13 +597,15 @@ func (f *memFile) Size() int64 { return f.size() } func (f *memFile) Mode() os.FileMode { + // At this time, we do not implement directory modes. if f.isdir { return os.FileMode(0755) | os.ModeDir } + // Under POSIX, symlinks have a fixed mode which can not be changed. if f.symlink != "" { return os.FileMode(0777) | os.ModeSymlink } - return os.FileMode(0644) + return os.FileMode(f.mode) } func (f *memFile) ModTime() time.Time { return f.modtime } func (f *memFile) IsDir() bool { return f.isdir } @@ -645,3 +681,8 @@ func (f *memFile) TransferError(err error) { f.err = err } + +func (f *memFile) chmod(mode uint32) { + const mask = uint32(os.ModePerm | s_ISUID | s_ISGID | s_ISVTX) + f.mode = (f.mode &^ mask) | (mode & mask) +}