Skip to content

Commit

Permalink
Use errors.Is everywhere.
Browse files Browse the repository at this point in the history
This should cover every single case where we were using == or != on an
err.

There may be other cases to address, but this covers a big one.
  • Loading branch information
Zeph / Liz Loss-Cutler-Hull committed Nov 1, 2023
1 parent a0a4a52 commit 8ee6563
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 22 deletions.
12 changes: 6 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func (c *Client) nextID() uint32 {
func (c *Client) recvVersion() error {
typ, data, err := c.recvPacket(0)
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return fmt.Errorf("server unexpectedly closed connection: %w", io.ErrUnexpectedEOF)
}

Expand Down Expand Up @@ -368,7 +368,7 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
return nil, unimplementedPacketErr(typ)
}
}
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
return attrs, err
Expand Down Expand Up @@ -1238,7 +1238,7 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
}

if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return written, nil // return nil explicitly.
}

Expand Down Expand Up @@ -1435,7 +1435,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
}

if packet.err != nil {
if packet.err == io.EOF {
if errors.Is(packet.err, io.EOF) {
return written, nil
}

Expand Down Expand Up @@ -1726,7 +1726,7 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
}

if err != nil {
if err != io.EOF {
if !errors.Is(err, io.EOF) {
errCh <- rwErr{off, err}
}
return
Expand Down Expand Up @@ -1878,7 +1878,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
}

if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return read, nil // return nil explicitly.
}

Expand Down
6 changes: 3 additions & 3 deletions client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ func TestClientReadSimple(t *testing.T) {
defer f2.Close()
stuff := make([]byte, 32)
n, err := f2.Read(stuff)
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
t.Fatalf("err: %v", err)
}
if n != 5 {
Expand Down Expand Up @@ -2152,7 +2152,7 @@ func TestMatch(t *testing.T) {
pattern := tt.pattern
s := tt.s
ok, err := Match(pattern, s)
if ok != tt.match || err != tt.err {
if ok != tt.match || !errors.Is(err, tt.err) {
t.Errorf("Match(%#q, %#q) = %v, %q want %v, %q", pattern, s, ok, errp(err), tt.match, errp(tt.err))
}
}
Expand Down Expand Up @@ -2411,7 +2411,7 @@ func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) {
for offset < size {
n, err := io.ReadFull(f2, buf)
offset += n
if err == io.ErrUnexpectedEOF && offset != size {
if errors.Is(err, io.ErrUnexpectedEOF) && offset != size {
b.Fatalf("read too few bytes! want: %d, got: %d", size, n)
}

Expand Down
3 changes: 2 additions & 1 deletion examples/go-sftp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package main

import (
"errors"
"flag"
"fmt"
"io"
Expand Down Expand Up @@ -136,7 +137,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
if err := server.Serve(); err == io.EOF {
if err := server.Serve(); errors.Is(err, io.EOF) {
server.Close()
log.Print("sftp client exited session.")
} else if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion examples/request-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package main

import (
"errors"
"flag"
"fmt"
"io"
Expand Down Expand Up @@ -120,7 +121,7 @@ func main() {

root := sftp.InMemHandler()
server := sftp.NewRequestServer(channel, root)
if err := server.Serve(); err == io.EOF {
if err := server.Serve(); errors.Is(err, io.EOF) {
server.Close()
log.Print("sftp client exited session.")
} else if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e
if _, err := io.ReadFull(r, b[:length]); err != nil {
// ReadFull only returns EOF if it has read no bytes.
// In this case, that means a partial packet, and thus unexpected.
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = io.ErrUnexpectedEOF
}
debug("recv packet %d bytes: err %v", length, err)
Expand Down
2 changes: 1 addition & 1 deletion request-example.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (fs *root) openfile(pathname string, flags uint32) (*memFile, error) {
pflags := newFileOpenFlags(flags)

file, err := fs.fetch(pathname)
if err == os.ErrNotExist {
if errors.Is(err, os.ErrNotExist) {
if !pflags.Creat {
return nil, os.ErrNotExist
}
Expand Down
2 changes: 1 addition & 1 deletion request-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (rs *RequestServer) Serve() error {
// make sure all open requests are properly closed
// (eg. possible on dropped connections, client crashes, etc.)
for handle, req := range rs.openRequests {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = io.ErrUnexpectedEOF
}
req.transferError(err)
Expand Down
3 changes: 2 additions & 1 deletion request-server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sftp

import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -240,7 +241,7 @@ func TestRequestJustRead(t *testing.T) {
defer rf.Close()
contents := make([]byte, 5)
n, err := rf.Read(contents)
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
t.Fatalf("err: %v", err)
}
assert.Equal(t, 5, n)
Expand Down
10 changes: 5 additions & 5 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde

n, err := rd.ReadAt(data, offset)
// only return EOF error if no data left to read
if err != nil && (err != io.EOF || n == 0) {
if err != nil && (!errors.Is(err, io.EOF) || n == 0) {
return statusFromError(pkt.id(), err)
}

Expand Down Expand Up @@ -422,7 +422,7 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o

n, err := rw.ReadAt(data, offset)
// only return EOF error if no data left to read
if err != nil && (err != io.EOF || n == 0) {
if err != nil && (!errors.Is(err, io.EOF) || n == 0) {
return statusFromError(pkt.id(), err)
}

Expand Down Expand Up @@ -507,7 +507,7 @@ func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {

switch r.Method {
case "List":
if err != nil && (err != io.EOF || n == 0) {
if err != nil && (!errors.Is(err, io.EOF) || n == 0) {
return statusFromError(pkt.id(), err)
}

Expand Down Expand Up @@ -560,7 +560,7 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {

switch r.Method {
case "Stat", "Lstat":
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
return statusFromError(pkt.id(), err)
}
if n == 0 {
Expand All @@ -576,7 +576,7 @@ func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
info: finfo[0],
}
case "Readlink":
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
return statusFromError(pkt.id(), err)
}
if n == 0 {
Expand Down
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func handlePacket(s *Server, p orderedRequest) error {
err = nil
data := p.getDataSlice(s.pktMgr.alloc, orderID)
n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) {
if _err != nil && (!errors.Is(_err, io.EOF) || n == 0) {
err = _err
}
rpkt = &sshFxpDataPacket{
Expand Down Expand Up @@ -354,7 +354,7 @@ func (svr *Server) Serve() error {
pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
if err != nil {
// Check whether the connection terminated cleanly in-between packets.
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
// we don't care about releasing allocated pages here, the server will quit and the allocator freed
Expand Down

0 comments on commit 8ee6563

Please sign in to comment.