Skip to content

Commit

Permalink
address code review
Browse files Browse the repository at this point in the history
  • Loading branch information
puellanivis committed Jan 19, 2024
1 parent d1903fb commit f3501dc
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 65 deletions.
30 changes: 18 additions & 12 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,10 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e
filename, data = unmarshalString(data)
_, data = unmarshalString(data) // discard longname
var attr *FileStat
attr, data = unmarshalAttrs(data)
attr, data, err = unmarshalAttrs(data)
if err != nil {
return nil, err
}
if filename == "." || filename == ".." {
continue
}
Expand Down Expand Up @@ -434,8 +437,8 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return fileInfoFromStat(attr, path.Base(p)), nil
attr, _, err := unmarshalAttrs(data)
return fileInfoFromStat(attr, path.Base(p)), err
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
Expand Down Expand Up @@ -660,8 +663,8 @@ func (c *Client) stat(path string) (*FileStat, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return attr, nil
attr, _, err := unmarshalAttrs(data)
return attr, err
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
Expand All @@ -684,8 +687,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return attr, nil
attr, _, err := unmarshalAttrs(data)
return attr, err
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
Expand Down Expand Up @@ -974,8 +977,8 @@ func (c *Client) RemoveAll(path string) error {

// File represents a remote file.
type File struct {
c *Client
path string
c *Client
path string

mu sync.RWMutex
handle string
Expand All @@ -992,6 +995,10 @@ func (f *File) Close() error {
return os.ErrClosed
}

// When `openssh-portable/sftp-server.c` is doing `handle_close`,
// it will unconditionally mark the handle as unused,
// so we need to also unconditionally mark this handle as invalid.

handle := f.handle
f.handle = ""

Expand Down Expand Up @@ -1485,6 +1492,8 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
}
}

// Stat returns the FileInfo structure describing file. If there is an
// error.
func (f *File) Stat() (os.FileInfo, error) {
f.mu.RLock()
defer f.mu.RUnlock()
Expand All @@ -1496,8 +1505,6 @@ func (f *File) Stat() (os.FileInfo, error) {
return f.stat()
}

// Stat returns the FileInfo structure describing file. If there is an
// error.
func (f *File) stat() (os.FileInfo, error) {
fs, err := f.c.fstat(f.handle)
if err != nil {
Expand Down Expand Up @@ -2055,7 +2062,6 @@ func (f *File) Sync() error {
return os.ErrClosed
}


id := f.c.nextID()
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
Expand Down
89 changes: 61 additions & 28 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,44 +174,77 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) {
return string(b[:n]), b[n:], nil
}

func unmarshalAttrs(b []byte) (*FileStat, []byte) {
flags, b := unmarshalUint32(b)
func unmarshalAttrs(b []byte) (*FileStat, []byte, error) {
flags, b, err := unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
return unmarshalFileStat(flags, b)
}

func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) {
var fs FileStat
var err error

if flags&sshFileXferAttrSize == sshFileXferAttrSize {
fs.Size, b, _ = unmarshalUint64Safe(b)
fs.Size, b, err = unmarshalUint64Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.UID, b, _ = unmarshalUint32Safe(b)
fs.GID, b, _ = unmarshalUint32Safe(b)
fs.UID, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
fs.GID, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
fs.Mode, b, _ = unmarshalUint32Safe(b)
fs.Mode, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
fs.Atime, b, _ = unmarshalUint32Safe(b)
fs.Mtime, b, _ = unmarshalUint32Safe(b)
fs.Atime, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
fs.Mtime, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
var count uint32
count, b, _ = unmarshalUint32Safe(b)
count, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}

ext := make([]StatExtended, count)
for i := uint32(0); i < count; i++ {
var typ string
var data string
typ, b, _ = unmarshalStringSafe(b)
data, b, _ = unmarshalStringSafe(b)
typ, b, err = unmarshalStringSafe(b)
if err != nil {
return nil, b, err
}
data, b, err = unmarshalStringSafe(b)
if err != nil {
return nil, b, err
}
ext[i] = StatExtended{
ExtType: typ,
ExtData: data,
}
}
fs.Extended = ext
}
return &fs, b
return &fs, b, nil
}

func unmarshalStatus(id uint32, data []byte) error {
Expand Down Expand Up @@ -734,15 +767,15 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
return nil
}

func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat {
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs
return attrs, nil
case []byte:
fs, _ := unmarshalFileStat(flags, attrs)
return fs
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}

Expand Down Expand Up @@ -1030,15 +1063,15 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
return nil
}

func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs
return attrs, nil
case []byte:
fs, _ := unmarshalFileStat(flags, attrs)
return fs
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}

Expand All @@ -1055,15 +1088,15 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
return nil
}

func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat {
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs
return attrs, nil
case []byte:
fs, _ := unmarshalFileStat(flags, attrs)
return fs
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs))
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}

Expand Down
11 changes: 7 additions & 4 deletions packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,10 @@ func TestUnmarshalAttrs(t *testing.T) {
}

for _, tt := range tests {
got, _ := unmarshalAttrs(tt.b)
got, _, err := unmarshalAttrs(tt.b)
if err != nil {
t.Fatal("unexpected error:", err)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want)
}
Expand Down Expand Up @@ -389,11 +392,11 @@ func TestSendPacket(t *testing.T) {
},
{
packet: &sshFxpOpenPacket{
ID: 3,
Path: "/foo",
ID: 3,
Path: "/foo",
Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC),
Flags: sshFileXferAttrPermissions,
Attrs: &FileStat{
Attrs: &FileStat{
Mode: 0o755,
},
},
Expand Down
2 changes: 1 addition & 1 deletion request-attrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ func (r *Request) AttrFlags() FileAttrFlags {
// Attributes parses file attributes byte blob and return them in a
// FileStat object.
func (r *Request) Attributes() *FileStat {
fs, _ := unmarshalFileStat(r.Flags, r.Attrs)
fs, _, _ := unmarshalFileStat(r.Flags, r.Attrs)
return fs
}
18 changes: 15 additions & 3 deletions request-attrs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRequestPflags(t *testing.T) {
Expand Down Expand Up @@ -33,15 +34,17 @@ func TestRequestAttributes(t *testing.T) {
at := []byte{}
at = marshalUint32(at, 1)
at = marshalUint32(at, 2)
testFs, _ := unmarshalFileStat(fl, at)
testFs, _, err := unmarshalFileStat(fl, at)
require.NoError(t, err)
assert.Equal(t, fa, *testFs)
// Size and Mode
fa = FileStat{Mode: 0700, Size: 99}
fl = uint32(sshFileXferAttrSize | sshFileXferAttrPermissions)
at = []byte{}
at = marshalUint64(at, 99)
at = marshalUint32(at, 0700)
testFs, _ = unmarshalFileStat(fl, at)
testFs, _, err = unmarshalFileStat(fl, at)
require.NoError(t, err)
assert.Equal(t, fa, *testFs)
// FileMode
assert.True(t, testFs.FileMode().IsRegular())
Expand All @@ -50,7 +53,16 @@ func TestRequestAttributes(t *testing.T) {
}

func TestRequestAttributesEmpty(t *testing.T) {
fs, b := unmarshalFileStat(sshFileXferAttrAll, nil)
fs, b, err := unmarshalFileStat(sshFileXferAttrAll, []byte{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // size
0x00, 0x00, 0x00, 0x00, // mode
0x00, 0x00, 0x00, 0x00, // mtime
0x00, 0x00, 0x00, 0x00, // atime
0x00, 0x00, 0x00, 0x00, // uid
0x00, 0x00, 0x00, 0x00, // gid
0x00, 0x00, 0x00, 0x00, // extended_count
})
require.NoError(t, err)
assert.Equal(t, &FileStat{
Extended: []StatExtended{},
}, fs)
Expand Down
28 changes: 18 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strconv"
"sync"
"syscall"
"time"
)

const (
Expand Down Expand Up @@ -462,10 +463,13 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
}

mode := os.FileMode(0o644)
// Like OpenSSH, we only handle permissions here, if the file is being created.
// Like OpenSSH, we only handle permissions here, and only when the file is being created.
// Otherwise, the permissions are ignored.
if p.Flags & sshFileXferAttrPermissions != 0 {
fs := p.unmarshalFileStat(p.Flags)
if p.Flags&sshFileXferAttrPermissions != 0 {
fs, err := p.unmarshalFileStat(p.Flags)
if err != nil {
return statusFromError(p.ID, err)
}
mode = fs.FileMode() & os.ModePerm
}

Expand Down Expand Up @@ -507,9 +511,7 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {

debug("setstat name %q", path)

fs := p.unmarshalFileStat(p.Flags)

var err error
fs, err := p.unmarshalFileStat(p.Flags)

if (p.Flags & sshFileXferAttrSize) != 0 {
if err == nil {
Expand Down Expand Up @@ -545,9 +547,7 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {

debug("fsetstat name %q", path)

fs := p.unmarshalFileStat(p.Flags)

var err error
fs, err := p.unmarshalFileStat(p.Flags)

if (p.Flags & sshFileXferAttrSize) != 0 {
if err == nil {
Expand All @@ -561,7 +561,15 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
if err == nil {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
switch f := interface{}(f).(type) {
case interface {
Chtimes(atime, mtime time.Time) error
}:
// future-compatible, if any when *os.File supports Chtimes.
err = f.Chtimes(fs.AccessTime(), fs.ModTime())
default:
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}
}
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
Expand Down

0 comments on commit f3501dc

Please sign in to comment.