Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ type chanConnection struct {
channel chan []byte
srcAddr, addr *net.UDPAddr
timeout time.Duration
complete chan string
}

func (c *chanConnection) sendTo(data []byte, addr *net.UDPAddr) error {
Expand All @@ -67,9 +66,7 @@ func (c *chanConnection) sendTo(data []byte, addr *net.UDPAddr) error {
func (c *chanConnection) readFrom(buffer []byte) (int, *net.UDPAddr, error) {
select {
case data := <-c.channel:
for i := range data {
buffer[i] = data[i]
}
copy(buffer, data)
return len(data), c.addr, nil
case <-time.After(c.timeout):
return 0, nil, makeError(c.addr.String())
Expand Down Expand Up @@ -98,7 +95,7 @@ func makeError(addr string) net.Error {
timeout: true,
temporary: true,
}
error.error = fmt.Errorf("Channel timeout: %v", addr)
error.error = fmt.Errorf("channel timeout: %v", addr)
return &error
}

Expand Down
10 changes: 5 additions & 5 deletions netascii/netascii_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package netascii

import (
"bytes"
"io/ioutil"
"io"
"strings"
"testing"
"testing/iotest"
Expand All @@ -20,7 +20,7 @@ var basic = map[string]string{
func TestTo(t *testing.T) {
for text, netascii := range basic {
to := ToReader(strings.NewReader(text))
n, _ := ioutil.ReadAll(to)
n, _ := io.ReadAll(to)
if !bytes.Equal(n, []byte(netascii)) {
t.Errorf("%q to netascii: %q != %q", text, n, netascii)
}
Expand All @@ -33,7 +33,7 @@ func TestFrom(t *testing.T) {
b := &bytes.Buffer{}
from := FromWriter(b)
r.WriteTo(from)
n, _ := ioutil.ReadAll(b)
n, _ := io.ReadAll(b)
if string(n) != text {
t.Errorf("%q from netascii: %q != %q", netascii, n, text)
}
Expand Down Expand Up @@ -69,7 +69,7 @@ func TestWriteRead(t *testing.T) {
two := &bytes.Buffer{}
from := FromWriter(two)
one.WriteTo(from)
text2, _ := ioutil.ReadAll(two)
text2, _ := io.ReadAll(two)
if text != string(text2) {
t.Errorf("text mismatch \n%x \n%x", text, text2)
}
Expand All @@ -82,7 +82,7 @@ func TestOneByte(t *testing.T) {
two := &bytes.Buffer{}
from := FromWriter(two)
one.WriteTo(from)
text2, _ := ioutil.ReadAll(two)
text2, _ := io.ReadAll(two)
if text != string(text2) {
t.Errorf("text mismatch \n%x \n%x", text, text2)
}
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (s *Server) Serve(conn net.PacketConn) error {
// but necessary at this point.
addr := net.ParseIP(host)
if addr == nil {
return fmt.Errorf("Failed to determine IP class of listening address")
return fmt.Errorf("failed to determine IP class of listening address")
}

if conn, ok := s.conn.(*net.UDPConn); ok {
Expand Down
78 changes: 49 additions & 29 deletions single_port.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,66 @@ package tftp

import (
"net"
"time"
)

func (s *Server) singlePortProcessRequests() error {
shuttingDown := false
for {
select {
case <-s.cancel.Done():
s.wg.Wait()
return nil
shuttingDown = true
default:
buf := make([]byte, s.maxBlockLen+4)
cnt, localAddr, srcAddr, maxSz, err := s.getPacket(buf)
if err != nil || cnt == 0 {
if s.hook != nil {
}

if shuttingDown {
// So we not blocked forever waiting for a packet
s.conn.SetReadDeadline(time.Now().Add(time.Second))
}

buf := make([]byte, s.maxBlockLen+4)
cnt, localAddr, srcAddr, maxSz, err := s.getPacket(buf)
if err != nil {
if shuttingDown {
s.wg.Wait()
return nil
}
if s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}
continue
}
if cnt == 0 {
continue
}
s.Lock()
if receiverChannel, ok := s.handlers[srcAddr.String()]; ok {
// Packet received for a transfer in progress.
s.Unlock()
select {
case receiverChannel <- buf[:cnt]:
default:
// We don't want to block the main loop if a channel is full
}
} else {
// No existing transfer for given source address. Start a new one.
if shuttingDown {
s.Unlock()
continue
}
lc := make(chan []byte, 1)
s.handlers[srcAddr.String()] = lc
s.Unlock()
go func() {
err := s.handlePacket(localAddr, srcAddr, buf, cnt, maxSz, lc)
if err != nil && s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}
continue
}
s.Lock()
if receiverChannel, ok := s.handlers[srcAddr.String()]; ok {
s.Unlock()
select {
case receiverChannel <- buf[:cnt]:
default:
// We don't want to block the main loop if a channel is full
}
} else {
lc := make(chan []byte, 1)
s.handlers[srcAddr.String()] = lc
s.Unlock()
go func() {
err := s.handlePacket(localAddr, srcAddr, buf, cnt, maxSz, lc)
if err != nil && s.hook != nil {
s.hook.OnFailure(TransferStats{
SenderAnticipateEnabled: s.sendAEnable,
}, err)
}
}()
}
}()
}
}
}
Expand Down
Loading