Skip to content

Commit

Permalink
Merge pull request #50 from mooneyow/master
Browse files Browse the repository at this point in the history
Add success and failure hooks
  • Loading branch information
pin authored Aug 5, 2019
2 parents 70c6a21 + c32dde3 commit 1fd60d6
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 38 deletions.
1 change: 1 addition & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (c *Client) RequestTSize(s bool) {
c.tsize = s
}

// Client stores data about a single TFTP client
type Client struct {
addr *net.UDPAddr
timeout time.Duration
Expand Down
34 changes: 28 additions & 6 deletions receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type receiver struct {
send []byte
receive []byte
addr *net.UDPAddr
filename string
localIP net.IP
tid int
conn connection
Expand All @@ -57,6 +58,8 @@ type receiver struct {
opts options
singlePort bool
maxBlockLen int
hook Hook
startTime time.Time
}

func (r *receiver) WriteTo(w io.Writer) (n int64, err error) {
Expand Down Expand Up @@ -217,7 +220,12 @@ func (r *receiver) terminate() error {
if r.conn == nil {
return nil
}
defer r.conn.close()
defer func() {
if r.hook != nil {
r.hook.OnSuccess(r.buildTransferStats())
}
r.conn.close()
}()
binary.BigEndian.PutUint16(r.send[2:4], r.block)
if r.dally {
for i := 0; i < 3; i++ {
Expand All @@ -227,19 +235,33 @@ func (r *receiver) terminate() error {
}
}
return fmt.Errorf("dallying termination failed")
} else {
err := r.conn.sendTo(r.send[:4], r.addr)
if err != nil {
return err
}
}
err := r.conn.sendTo(r.send[:4], r.addr)
if err != nil {
return err
}
return nil
}

func (r *receiver) buildTransferStats() TransferStats {
return TransferStats{
RemoteAddr: r.addr.IP,
Filename: r.filename,
Tid: r.tid,
TotalBlocks: r.block,
Mode: r.mode,
Opts: r.opts,
Duration: time.Now().Sub(r.startTime),
}
}

func (r *receiver) abort(err error) error {
if r.conn == nil {
return nil
}
if r.hook != nil {
r.hook.OnFailure(r.buildTransferStats(), err)
}
n := packERROR(r.send, 1, err.Error())
err = r.conn.sendTo(r.send[:n], r.addr)
if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type OutgoingTransfer interface {
type sender struct {
conn connection
addr *net.UDPAddr
filename string
localIP net.IP
tid int
send []byte
Expand All @@ -45,6 +46,8 @@ type sender struct {
maxBlockLen int
mode string
opts options
hook Hook
startTime time.Time
}

func (s *sender) RemoteAddr() net.UDPAddr { return *s.addr }
Expand Down Expand Up @@ -107,6 +110,9 @@ func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
s.abort(err)
return n, err
}
if s.hook != nil {
s.hook.OnSuccess(s.buildTransferStats())
}
s.conn.close()
return n, nil
}
Expand All @@ -120,6 +126,9 @@ func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
return n, err
}
if l < len(s.send)-4 {
if s.hook != nil {
s.hook.OnSuccess(s.buildTransferStats())
}
s.conn.close()
return n, nil
}
Expand Down Expand Up @@ -243,10 +252,26 @@ func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) {
}
}

func (s *sender) buildTransferStats() TransferStats {
return TransferStats{
RemoteAddr: s.addr.IP,
Filename: s.filename,
Tid: s.tid,
SenderAnticipateEnabled: s.sendA.enabled,
TotalBlocks: s.block,
Mode: s.mode,
Opts: s.opts,
Duration: time.Now().Sub(s.startTime),
}
}

func (s *sender) abort(err error) error {
if s.conn == nil {
return nil
}
if s.hook != nil {
s.hook.OnFailure(s.buildTransferStats(), err)
}
n := packERROR(s.send, 1, err.Error())
err = s.conn.sendTo(s.send[:n], s.addr)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion sender_anticipate.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (s *sender) sendDatagramAnticipate() (*net.UDPAddr, error) {
if err1 != nil {
return nil, err1
}
var err error = nil
var err error
ksz := uint(len(s.sendA.sends))
knum := s.sendA.num
if knum > ksz {
Expand Down
41 changes: 37 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func NewServer(readHandler func(filename string, rf io.ReaderFrom) error,
timeout: defaultTimeout,
retries: defaultRetries,
runGC: make(chan []string),
gcInterval: 1 * time.Minute,
gcThreshold: 100,
packetReadTimeout: 100 * time.Millisecond,
readHandler: readHandler,
writeHandler: writeHandler,
Expand All @@ -40,9 +40,11 @@ type RequestPacketInfo interface {
LocalIP() net.IP
}

// Server is an instance of a TFTP server
type Server struct {
readHandler func(filename string, rf io.ReaderFrom) error
writeHandler func(filename string, wt io.WriterTo) error
hook Hook
backoff backoffFunc
conn *net.UDPConn
conn6 *ipv6.PacketConn
Expand All @@ -60,10 +62,28 @@ type Server struct {
handlers map[string]chan []byte
runGC chan []string
gcCollect chan string
gcInterval time.Duration
gcThreshold int
packetReadTimeout time.Duration
}

// TransferStats contains details about a single TFTP transfer
type TransferStats struct {
RemoteAddr net.IP
Filename string
Tid int
SenderAnticipateEnabled bool
TotalBlocks uint16
Mode string
Opts options
Duration time.Duration
}

// Hook is an interface used to provide the server with success and failure hooks
type Hook interface {
OnSuccess(stats TransferStats)
OnFailure(stats TransferStats, err error)
}

// SetAnticipate provides an experimental feature in which when a packets
// is requested the server will keep sending a number of packets before
// checking whether an ack has been received. It improves tftp downloading
Expand All @@ -83,6 +103,11 @@ func (s *Server) SetAnticipate(winsz uint) {
}
}

// SetHook sets the Hook for success and failure of transfers
func (s *Server) SetHook(hook Hook) {
s.hook = hook
}

// EnableSinglePort enables an experimental mode where the server will
// serve all connections on port 69 only. There will be no random TIDs
// on the server side.
Expand Down Expand Up @@ -203,8 +228,10 @@ func (s *Server) Serve(conn *net.UDPConn) error {
} else {
err = s.processRequest()
}
if err != nil {
// TODO: add logging handler
if err != nil && s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}
}
}
Expand Down Expand Up @@ -304,6 +331,9 @@ func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer
mode: mode,
opts: opts,
maxBlockLen: maxBlockLen,
hook: s.hook,
filename: filename,
startTime: time.Now(),
}
if s.singlePort {
wt.conn = &chanConnection{
Expand Down Expand Up @@ -353,6 +383,9 @@ func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer
mode: mode,
opts: opts,
maxBlockLen: maxBlockLen,
hook: s.hook,
filename: filename,
startTime: time.Now(),
}
if s.singlePort {
rf.conn = &chanConnection{
Expand Down
40 changes: 28 additions & 12 deletions single_port.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package tftp

import (
"net"
"time"
)

func (s *Server) singlePortProcessRequests() error {
Expand All @@ -18,7 +17,15 @@ func (s *Server) singlePortProcessRequests() error {
// We've received a new connection on the same IP+Port tuple
// as a previous connection before garbage collection has occured
s.handlers[srcAddr.String()] = make(chan []byte)
go s.handlePacket(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, blockLength, s.handlers[srcAddr.String()])
go func(localAddr net.IP, remoteAddr *net.UDPAddr, buffer []byte, n, maxBlockLen int, listener chan []byte) {
err := s.handlePacket(localAddr, remoteAddr, buffer, n, maxBlockLen, listener)
if err != nil && s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}

}(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, blockLength, s.handlers[srcAddr.String()])
s.singlePortProcessRequests()
}
}()
Expand All @@ -29,13 +36,17 @@ func (s *Server) singlePortProcessRequests() error {
return nil
case handlersToFree := <-s.runGC:
for _, handler := range handlersToFree {
s.handlers[handler] = nil
delete(s.handlers, handler)
}
default:
buf = s.bufPool.Get().([]byte)
cnt, localAddr, srcAddr, err = s.getPacket(buf)
if err != nil || cnt == 0 {
// TODO: add logging handler
if s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}
s.bufPool.Put(buf)
continue
}
Expand All @@ -47,17 +58,22 @@ func (s *Server) singlePortProcessRequests() error {
}
} else {
s.handlers[srcAddr.String()] = make(chan []byte, datagramLength)
go s.handlePacket(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, blockLength, s.handlers[srcAddr.String()])
go func(localAddr net.IP, remoteAddr *net.UDPAddr, buffer []byte, n, maxBlockLen int, listener chan []byte) {
err := s.handlePacket(localAddr, remoteAddr, buffer, n, maxBlockLen, listener)
if err != nil && s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}

}(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, blockLength, s.handlers[srcAddr.String()])
}
}
}
}

func (s *Server) getPacket(buf []byte) (int, net.IP, *net.UDPAddr, error) {
if s.conn6 != nil {
// TODO: investigate why deadline is necessary
// ReadFrom seems to behave badly without it
s.conn6.SetReadDeadline(time.Now().Add(s.packetReadTimeout))
cnt, control, srcAddr, err := s.conn6.ReadFrom(buf)
if err != nil || cnt == 0 {
return 0, nil, nil, err
Expand All @@ -68,7 +84,6 @@ func (s *Server) getPacket(buf []byte) (int, net.IP, *net.UDPAddr, error) {
}
return cnt, localAddr, srcAddr.(*net.UDPAddr), nil
} else if s.conn4 != nil {
s.conn4.SetReadDeadline(time.Now().Add(s.packetReadTimeout))
cnt, control, srcAddr, err := s.conn4.ReadFrom(buf)
if err != nil || cnt == 0 {
return 0, nil, nil, err
Expand All @@ -95,9 +110,10 @@ func (s *Server) internalGC() {
select {
case newHandler := <-s.gcCollect:
completedHandlers = append(completedHandlers, newHandler)
case <-time.After(s.gcInterval):
s.runGC <- completedHandlers
completedHandlers = nil
if len(completedHandlers) > s.gcThreshold {
s.runGC <- completedHandlers
completedHandlers = nil
}
}
}
}
4 changes: 2 additions & 2 deletions single_port_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ func TestZeroLengthSinglePort(t *testing.T) {
testSendReceive(t, c, 0)
}

func TestSendReveiveSinglePort(t *testing.T) {
func TestSendReceiveSinglePort(t *testing.T) {
s, c := makeTestServer(true)
defer s.Shutdown()
for i := 600; i < 1000; i++ {
testSendReceive(t, c, 5000+int64(i))
}
}

func TestSendReveiveSinglePortWithBlockSize(t *testing.T) {
func TestSendReceiveSinglePortWithBlockSize(t *testing.T) {
s, c := makeTestServer(true)
defer s.Shutdown()
for i := 600; i < 1000; i++ {
Expand Down
2 changes: 1 addition & 1 deletion tftp_anticipate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
func TestAnticipateWindow900(t *testing.T) {
s, c := makeTestServerAnticipateWindow()
defer s.Shutdown()
for i := 600; i < 4000; i += 1 {
for i := 600; i < 4000; i++ {
c.blksize = i
testSendReceive(t, c, 9000+int64(i))
}
Expand Down
Loading

0 comments on commit 1fd60d6

Please sign in to comment.