Skip to content

Commit

Permalink
Merge pull request #17 from puellanivis/issue/avoid-truncated-udp-reads
Browse files Browse the repository at this point in the history
Avoid truncated udp reads
  • Loading branch information
puellanivis committed May 8, 2020
2 parents a0b7ad0 + 8eaeac5 commit 01e1625
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 54 deletions.
59 changes: 42 additions & 17 deletions lib/files/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@ import (
"time"
)

// defaultBufferSize is the copy-buffer size used by Copy when the caller
// does not supply one. 64 KiB matches the maximum datagram payload size,
// so a single buffered read can hold a full packet.
const defaultBufferSize = 64 * 1024

// ErrWatchdogExpired is returned by files.Copy, if the watchdog time expires during a read.
var ErrWatchdogExpired error = watchdogExpiredError{}

// watchdogExpiredError is the concrete type behind ErrWatchdogExpired.
// It implements Timeout and Temporary so callers treating it like a
// net.Error can classify it as a retryable timeout.
type watchdogExpiredError struct{}

func (watchdogExpiredError) Error() string   { return "watchdog expired" }
func (watchdogExpiredError) Timeout() bool   { return true }
func (watchdogExpiredError) Temporary() bool { return true }

// Copy is a context aware version of io.Copy.
// Do not use to Discard a reader, as a canceled context would stop the read, and it would not be fully discarded.
Expand All @@ -27,7 +36,7 @@ func Copy(ctx context.Context, dst io.Writer, src io.Reader, opts ...CopyOption)
// we allocate a buffer to use as a temporary buffer, rather than alloc new every time.
c.buffer = make([]byte, defaultBufferSize)
}
l := int64(len(c.buffer))
buflen := int64(len(c.buffer))

var keepingMetrics bool

Expand Down Expand Up @@ -62,46 +71,62 @@ func Copy(ctx context.Context, dst io.Writer, src io.Reader, opts ...CopyOption)
bwWindow = make([]bwSnippet, c.bwCount)
}

// Prevent an accidental write outside of returning from this function.
ctx, cancel := context.WithCancel(ctx)
defer cancel()

w := &deadlineWriter{
ctx: ctx,
w: dst,
}

r := &fuzzyLimitedReader{
R: src,
N: buflen,
}

t := time.NewTimer(c.runningTimeout)
if c.runningTimeout <= 0 {
if !t.Stop() {
<-t.C
}
}

start := time.Now()

var bwAccum int64
last := start
next := last.Add(c.bwInterval)

for {
done := make(chan struct{})

ctx := ctx // shadow context intentionally, we might set a timeout later
cancel := func() {} // noop cancel
r.N = buflen // reset fuzzyLimitedReader

if c.runningTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.runningTimeout)
}

w := &deadlineWriter{
ctx: ctx,
w: dst,
if !t.Stop() {
<-t.C
}
t.Reset(c.runningTimeout)
}
r := io.LimitReader(src, l)

var n int64

done := make(chan struct{})
go func() {
defer close(done)

n, err = io.CopyBuffer(w, r, c.buffer)

if n < l && err == nil {
if n < buflen && err == nil {
err = io.EOF
}
}()

select {
case <-done:
cancel()

case <-t.C:
return written, ErrWatchdogExpired

case <-ctx.Done():
cancel()
return written, ctx.Err()
}

Expand Down
25 changes: 25 additions & 0 deletions lib/files/fuzzy_limited_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package files

import (
"io"
)

// fuzzyLimitedReader reads at least N-bytes from the underlying reader.
// It does not ensure that it reads only or at-most N-bytes.
// Each call to Read updates N to reflect the new amount remaining.
// Read returns EOF when N <= 0 or when the underlying R returns EOF.
type fuzzyLimitedReader struct {
R io.Reader // underlying reader
N int64 // stop reading after at least this much
}

func (r *fuzzyLimitedReader) Read(b []byte) (n int, err error) {
if r.N <= 0 {
return 0, io.EOF
}

n, err = r.R.Read(b)
r.N -= int64(n)

return n, err
}
122 changes: 109 additions & 13 deletions lib/files/socketfiles/dgram.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package socketfiles
import (
"context"
"io"
"net"
"os"
"sync"
"time"
Expand All @@ -14,12 +13,13 @@ import (
// datagramWriter assembles byte-stream writes into fixed-size packets
// destined for a datagram socket.
type datagramWriter struct {
	*wrapper.Info

	mu     sync.Mutex
	closed chan struct{} // closed on Close to stop the context-watcher goroutine

	noerrs bool

	buf []byte // packet assembly buffer; len(buf) is the packet size
	off int    // bytes currently staged in buf

	sock *socket
}
Expand Down Expand Up @@ -67,14 +67,22 @@ func (w *datagramWriter) SetPacketSize(size int) int {
w.sock.packetSize = len(w.buf)
w.sock.updateDelay(len(w.buf))

// Update filename.
w.Info.SetNameFromURL(w.sock.uri())

return prev
}

// SetBitrate sets a new maximum bitrate on the underlying socket and
// returns the previous value. The current packet size is passed along so
// the socket can recompute its bitrate-derived pacing state.
func (w *datagramWriter) SetBitrate(bitrate int) int {
	w.mu.Lock()
	defer w.mu.Unlock()

	prev := w.sock.setBitrate(bitrate, len(w.buf))

	// Update filename.
	w.Info.SetNameFromURL(w.sock.uri())

	return prev
}

func (w *datagramWriter) Sync() error {
Expand Down Expand Up @@ -185,19 +193,15 @@ func (w *datagramWriter) Write(b []byte) (n int, err error) {
}

func newDatagramWriter(ctx context.Context, sock *socket) *datagramWriter {
var buf []byte
if sock.packetSize > 0 {
buf = make([]byte, sock.packetSize)
}

w := &datagramWriter{
Info: wrapper.NewInfo(sock.uri(), 0, time.Now()),
sock: sock,

closed: make(chan struct{}),
buf: buf,
}

w.SetPacketSize(sock.packetSize)

go func() {
select {
case <-w.closed:
Expand All @@ -211,16 +215,108 @@ func newDatagramWriter(ctx context.Context, sock *socket) *datagramWriter {

// datagramReader adapts a packet-oriented socket to Go's stream-style
// io.Reader: packets that do not fit the caller's buffer are staged in an
// internal buffer so short reads do not drop data.
type datagramReader struct {
	*wrapper.Info
	sock *socket

	mu sync.Mutex

	buf  []byte // holds one packet when the caller's buffer was too small
	cnt  int    // count of valid bytes in buf
	read int    // count of bytes of buf already handed to the caller
}

// defaultMaxPacketSize is the maximum size of an IPv4 payload, and non-Jumbogram IPv6 payload.
// This is an overly safe default.
// SetPacketSize falls back to it when the requested size is zero or negative.
const defaultMaxPacketSize = 64 * 1024

// SetPacketSize resizes the internal packet buffer to size bytes and
// returns the previous length. A non-positive size selects
// defaultMaxPacketSize. Any buffered read state that falls outside the new
// length is clamped, and the socket's max packet size and the reported
// filename are refreshed to match.
func (r *datagramReader) SetPacketSize(size int) int {
	r.mu.Lock()
	defer r.mu.Unlock()

	prev := len(r.buf)

	if size <= 0 {
		size = defaultMaxPacketSize
	}

	if size <= len(r.buf) {
		// Shrink in place, keeping the existing backing array.
		r.buf = r.buf[:size]
	} else {
		// Grow to the requested length, preserving current contents.
		grown := make([]byte, size)
		copy(grown, r.buf)
		r.buf = grown
	}

	// Clamp buffered-read cursors so they never point past the new end.
	if n := len(r.buf); r.read > n {
		r.read = n
	}
	if n := len(r.buf); r.cnt > n {
		r.cnt = n
	}

	r.sock.maxPacketSize = len(r.buf)

	// Update filename.
	r.Info.SetNameFromURL(r.sock.uri())

	return prev
}

// Seek implements io.Seeker by always failing: a datagram source has no
// byte offset to seek within, so every call returns os.ErrInvalid.
func (r *datagramReader) Seek(offset int64, whence int) (int64, error) {
	return 0, os.ErrInvalid
}

// Close closes the underlying connection, which also unblocks any
// concurrent Read waiting on the socket.
func (r *datagramReader) Close() error {
	// Do not attempt to acquire the Mutex.
	// Doing so will deadlock with a concurrent blocking Read(),
	// and prevent read cancellation.
	return r.sock.conn.Close()
}

// ReadPacket reads a single packet from a data source.
// It is up to the caller to ensure that the given buffer is sufficient to read a full packet.
// NOTE(review): datagrams larger than len(b) are presumably truncated by the
// transport — confirm; Read's buffering exists to avoid that case.
func (r *datagramReader) ReadPacket(b []byte) (n int, err error) {
	return r.sock.conn.Read(b)
}

// Read performs reads from a datagram source into a continuous stream.
//
// It does this by ensuring that each read on the datagram socket is to a sufficiently sized buffer.
// If the given buffer is too small, it will read to an internal buffer with length set from max_pkt_size,
// and following reads will read from that buffer until it is empty.
//
// Properly, a datagram source should know it is reading packets,
// and ensure each given buffer is large enough to read the maximum packet size expected.
// Unfortunately, some APIs in Go can expect Read()s to operate as a continuous stream instead of packets,
// and that a short read buffer, will just leave the rest of the unread data ready to read, not dropped on the floor.
func (r *datagramReader) Read(b []byte) (n int, err error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.read >= r.cnt {
		// Nothing is buffered.

		if len(b) >= len(r.buf) {
			// The read can be done directly.
			return r.ReadPacket(b)
		}

		// Given buffer is too small, use internal buffer.
		r.read = 0 // reset read start buffer.
		r.cnt, err = r.ReadPacket(r.buf)
	}

	// Hand out as much buffered data as fits in b; any remainder stays in
	// r.buf for subsequent calls. A ReadPacket error (err set above) is
	// returned alongside whatever bytes were received on this call.
	n = copy(b, r.buf[r.read:r.cnt])
	r.read += n
	return n, err
}

func newDatagramReader(ctx context.Context, sock *socket) *datagramReader {
return &datagramReader{
r := &datagramReader{
Info: wrapper.NewInfo(sock.uri(), 0, time.Now()),
Conn: sock.conn,
sock: sock,
}

r.SetPacketSize(sock.maxPacketSize)

return r
}
37 changes: 27 additions & 10 deletions lib/files/socketfiles/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,24 @@ var (

// URL query field keys recognized in socket URLs.
const (
	FieldBufferSize    = "buffer_size"
	FieldLocalAddress  = "localaddr"
	FieldLocalPort     = "localport"
	FieldMaxBitrate    = "max_bitrate"
	FieldMaxPacketSize = "max_pkt_size"
	FieldPacketSize    = "pkt_size"
	FieldTOS           = "tos"
	FieldTTL           = "ttl"
)

type socket struct {
conn net.Conn

addr, qaddr net.Addr

bufferSize int
packetSize int
bufferSize int
packetSize int
maxPacketSize int

tos, ttl int

Expand Down Expand Up @@ -90,6 +92,9 @@ func (s *socket) uriQuery() url.Values {
if s.packetSize > 0 {
q.Set(FieldPacketSize, strconv.Itoa(s.packetSize))
}
if s.maxPacketSize > 0 {
q.Set(FieldMaxPacketSize, strconv.Itoa(s.maxPacketSize))
}
}

switch network {
Expand Down Expand Up @@ -127,12 +132,24 @@ func sockReader(conn net.Conn, q url.Values) (*socket, error) {
}
}

laddr := conn.LocalAddr()

var maxPacketSize int
switch laddr.Network() {
case "udp", "udp4", "udp6", "unixgram", "unixpacket":
maxPacketSize, err = getSize(q, FieldMaxPacketSize)
if err != nil {
return nil, err
}
}

return &socket{
conn: conn,

addr: conn.LocalAddr(),

bufferSize: bufferSize,
bufferSize: bufferSize,
maxPacketSize: maxPacketSize,
}, nil
}

Expand Down
Loading

0 comments on commit 01e1625

Please sign in to comment.