From f3501dc6ba301548dc514108039e66f319748a1a Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 19 Jan 2024 01:23:22 +0000 Subject: [PATCH] address code review --- client.go | 30 ++++++++----- packet.go | 89 ++++++++++++++++++++++++++------------ packet_test.go | 11 +++-- request-attrs.go | 2 +- request-attrs_test.go | 18 ++++++-- server.go | 28 +++++++----- server_integration_test.go | 2 +- server_test.go | 12 ++--- 8 files changed, 127 insertions(+), 65 deletions(-) diff --git a/client.go b/client.go index 1d55aaea..12d105ad 100644 --- a/client.go +++ b/client.go @@ -363,7 +363,10 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e filename, data = unmarshalString(data) _, data = unmarshalString(data) // discard longname var attr *FileStat - attr, data = unmarshalAttrs(data) + attr, data, err = unmarshalAttrs(data) + if err != nil { + return nil, err + } if filename == "." || filename == ".." { continue } @@ -434,8 +437,8 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { if sid != id { return nil, &unexpectedIDErr{id, sid} } - attr, _ := unmarshalAttrs(data) - return fileInfoFromStat(attr, path.Base(p)), nil + attr, _, err := unmarshalAttrs(data) + return fileInfoFromStat(attr, path.Base(p)), err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) default: @@ -660,8 +663,8 @@ func (c *Client) stat(path string) (*FileStat, error) { if sid != id { return nil, &unexpectedIDErr{id, sid} } - attr, _ := unmarshalAttrs(data) - return attr, nil + attr, _, err := unmarshalAttrs(data) + return attr, err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) default: @@ -684,8 +687,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) { if sid != id { return nil, &unexpectedIDErr{id, sid} } - attr, _ := unmarshalAttrs(data) - return attr, nil + attr, _, err := unmarshalAttrs(data) + return attr, err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) default: @@ -974,8 +977,8 @@ func (c *Client) RemoveAll(path string) error { // File represents a remote file. type File struct { - c *Client - path string + c *Client + path string mu sync.RWMutex handle string @@ -992,6 +995,10 @@ func (f *File) Close() error { return os.ErrClosed } + // When `openssh-portable/sftp-server.c` is doing `handle_close`, + // it will unconditionally mark the handle as unused, + // so we need to also unconditionally mark this handle as invalid. + handle := f.handle f.handle = "" @@ -1485,6 +1492,8 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { } } +// Stat returns the FileInfo structure describing file. If there is an +// error. func (f *File) Stat() (os.FileInfo, error) { f.mu.RLock() defer f.mu.RUnlock() @@ -1496,8 +1505,6 @@ func (f *File) Stat() (os.FileInfo, error) { return f.stat() } -// Stat returns the FileInfo structure describing file. If there is an -// error. func (f *File) stat() (os.FileInfo, error) { fs, err := f.c.fstat(f.handle) if err != nil { @@ -2055,7 +2062,6 @@ func (f *File) Sync() error { return os.ErrClosed } - id := f.c.nextID() typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ ID: id, diff --git a/packet.go b/packet.go index 2fea2bef..f37cd4dc 100644 --- a/packet.go +++ b/packet.go @@ -174,36 +174,69 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) { return string(b[:n]), b[n:], nil } -func unmarshalAttrs(b []byte) (*FileStat, []byte) { - flags, b := unmarshalUint32(b) +func unmarshalAttrs(b []byte) (*FileStat, []byte, error) { + flags, b, err := unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } return unmarshalFileStat(flags, b) } -func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { +func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) { var fs FileStat + var err error + if flags&sshFileXferAttrSize == sshFileXferAttrSize { - fs.Size, b, _ = unmarshalUint64Safe(b) + fs.Size, b, err = unmarshalUint64Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { - fs.UID, b, _ = unmarshalUint32Safe(b) - fs.GID, b, _ = unmarshalUint32Safe(b) + fs.UID, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + fs.GID, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { - fs.Mode, b, _ = unmarshalUint32Safe(b) + fs.Mode, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime { - fs.Atime, b, _ = unmarshalUint32Safe(b) - fs.Mtime, b, _ = unmarshalUint32Safe(b) + fs.Atime, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + fs.Mtime, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrExtended == sshFileXferAttrExtended { var count uint32 - count, b, _ = unmarshalUint32Safe(b) + count, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + ext := make([]StatExtended, count) for i := uint32(0); i < count; i++ { var typ string var data string - typ, b, _ = unmarshalStringSafe(b) - data, b, _ = unmarshalStringSafe(b) + typ, b, err = unmarshalStringSafe(b) + if err != nil { + return nil, b, err + } + data, b, err = unmarshalStringSafe(b) + if err != nil { + return nil, b, err + } ext[i] = StatExtended{ ExtType: typ, ExtData: data, @@ -211,7 +244,7 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { } fs.Extended = ext } - return &fs, b + return &fs, b, nil } func unmarshalStatus(id uint32, data []byte) error { @@ -734,15 +767,15 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat { +func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { switch attrs := p.Attrs.(type) { case *FileStat: - return attrs + return attrs, nil case []byte: - fs, _ := unmarshalFileStat(flags, attrs) - return fs + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err default: - panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) } } @@ -1030,15 +1063,15 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat { +func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { switch attrs := p.Attrs.(type) { case *FileStat: - return attrs + return attrs, nil case []byte: - fs, _ := unmarshalFileStat(flags, attrs) - return fs + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err default: - panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) } } @@ -1055,15 +1088,15 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat { +func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { switch attrs := p.Attrs.(type) { case *FileStat: - return attrs + return attrs, nil case []byte: - fs, _ := unmarshalFileStat(flags, attrs) - return fs + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err default: - panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) } } diff --git a/packet_test.go b/packet_test.go index 6278ca4f..98455abe 100644 --- a/packet_test.go +++ b/packet_test.go @@ -284,7 +284,10 @@ func TestUnmarshalAttrs(t *testing.T) { } for _, tt := range tests { - got, _ := unmarshalAttrs(tt.b) + got, _, err := unmarshalAttrs(tt.b) + if err != nil { + t.Fatal("unexpected error:", err) + } if !reflect.DeepEqual(got, tt.want) { t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want) } @@ -389,11 +392,11 @@ func TestSendPacket(t *testing.T) { }, { packet: &sshFxpOpenPacket{ - ID: 3, - Path: "/foo", + ID: 3, + Path: "/foo", Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC), Flags: sshFileXferAttrPermissions, - Attrs: &FileStat{ + Attrs: &FileStat{ Mode: 0o755, }, }, diff --git a/request-attrs.go b/request-attrs.go index c86539cc..476c5651 100644 --- a/request-attrs.go +++ b/request-attrs.go @@ -52,6 +52,6 @@ func (r *Request) AttrFlags() FileAttrFlags { // Attributes parses file attributes byte blob and return them in a // FileStat object. func (r *Request) Attributes() *FileStat { - fs, _ := unmarshalFileStat(r.Flags, r.Attrs) + fs, _, _ := unmarshalFileStat(r.Flags, r.Attrs) return fs } diff --git a/request-attrs_test.go b/request-attrs_test.go index 658afca0..b1b559b8 100644 --- a/request-attrs_test.go +++ b/request-attrs_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRequestPflags(t *testing.T) { @@ -33,7 +34,8 @@ func TestRequestAttributes(t *testing.T) { at := []byte{} at = marshalUint32(at, 1) at = marshalUint32(at, 2) - testFs, _ := unmarshalFileStat(fl, at) + testFs, _, err := unmarshalFileStat(fl, at) + require.NoError(t, err) assert.Equal(t, fa, *testFs) // Size and Mode fa = FileStat{Mode: 0700, Size: 99} @@ -41,7 +43,8 @@ func TestRequestAttributes(t *testing.T) { at = []byte{} at = marshalUint64(at, 99) at = marshalUint32(at, 0700) - testFs, _ = unmarshalFileStat(fl, at) + testFs, _, err = unmarshalFileStat(fl, at) + require.NoError(t, err) assert.Equal(t, fa, *testFs) // FileMode assert.True(t, testFs.FileMode().IsRegular()) @@ -50,7 +53,16 @@ func TestRequestAttributes(t *testing.T) { } func TestRequestAttributesEmpty(t *testing.T) { - fs, b := unmarshalFileStat(sshFileXferAttrAll, nil) + fs, b, err := unmarshalFileStat(sshFileXferAttrAll, []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // size + 0x00, 0x00, 0x00, 0x00, // mode + 0x00, 0x00, 0x00, 0x00, // mtime + 0x00, 0x00, 0x00, 0x00, // atime + 0x00, 0x00, 0x00, 0x00, // uid + 0x00, 0x00, 0x00, 0x00, // gid + 0x00, 0x00, 0x00, 0x00, // extended_count + }) + require.NoError(t, err) assert.Equal(t, &FileStat{ Extended: []StatExtended{}, }, fs) diff --git a/server.go b/server.go index 6e53e264..16f1cabc 100644 --- a/server.go +++ b/server.go @@ -13,6 +13,7 @@ import ( "strconv" "sync" "syscall" + "time" ) const ( @@ -462,10 +463,13 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { } mode := os.FileMode(0o644) - // Like OpenSSH, we only handle permissions here, if the file is being created. + // Like OpenSSH, we only handle permissions here, and only when the file is being created. // Otherwise, the permissions are ignored. - if p.Flags & sshFileXferAttrPermissions != 0 { - fs := p.unmarshalFileStat(p.Flags) + if p.Flags&sshFileXferAttrPermissions != 0 { + fs, err := p.unmarshalFileStat(p.Flags) + if err != nil { + return statusFromError(p.ID, err) + } mode = fs.FileMode() & os.ModePerm } @@ -507,9 +511,7 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { debug("setstat name %q", path) - fs := p.unmarshalFileStat(p.Flags) - - var err error + fs, err := p.unmarshalFileStat(p.Flags) if (p.Flags & sshFileXferAttrSize) != 0 { if err == nil { @@ -545,9 +547,7 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { debug("fsetstat name %q", path) - fs := p.unmarshalFileStat(p.Flags) - - var err error + fs, err := p.unmarshalFileStat(p.Flags) if (p.Flags & sshFileXferAttrSize) != 0 { if err == nil { @@ -561,7 +561,15 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { } if (p.Flags & sshFileXferAttrACmodTime) != 0 { if err == nil { - err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) + switch f := interface{}(f).(type) { + case interface { + Chtimes(atime, mtime time.Time) error + }: + // future-compatible, if any when *os.File supports Chtimes. + err = f.Chtimes(fs.AccessTime(), fs.ModTime()) + default: + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) + } } } if (p.Flags & sshFileXferAttrUIDGID) != 0 { diff --git a/server_integration_test.go b/server_integration_test.go index 74a6f8a1..398ea865 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -606,7 +606,7 @@ ls -l /usr/bin/ // words[7] as timestamps on dirs can vary for things like /tmp case 8: // words[8] can either have full path or just the filename - bad = !strings.HasSuffix(opWord, "/" + goWord) + bad = !strings.HasSuffix(opWord, "/"+goWord) default: bad = true } diff --git a/server_test.go b/server_test.go index 110e0dee..4cec3123 100644 --- a/server_test.go +++ b/server_test.go @@ -228,11 +228,11 @@ func TestOpenWithPermissions(t *testing.T) { id2 := client.nextID() typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{ - ID: id1, - Path: tmppath, + ID: id1, + Path: tmppath, Pflags: pflags, Flags: sshFileXferAttrPermissions, - Attrs: &FileStat{ + Attrs: &FileStat{ Mode: 0o745, }, }) @@ -259,11 +259,11 @@ func TestOpenWithPermissions(t *testing.T) { // Existing files should not have their permissions changed. typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{ - ID: id2, - Path: tmppath, + ID: id2, + Path: tmppath, Pflags: pflags, Flags: sshFileXferAttrPermissions, - Attrs: &FileStat{ + Attrs: &FileStat{ Mode: 0o755, }, })