Skip to content

Commit

Permalink
Add context wrapper for Conn.Read
Browse files Browse the repository at this point in the history
Add wrapper to cancel Read by context.
Note that underlying Conn must support SetDeadline.
  • Loading branch information
at-wat committed Mar 5, 2020
1 parent 5bc0a3f commit 6bf60d2
Show file tree
Hide file tree
Showing 2 changed files with 353 additions and 0 deletions.
156 changes: 156 additions & 0 deletions internal/net/connctx/connctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Package connctx wraps net.Conn using context.Context.
package connctx

import (
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"time"
)

// ErrClosing is returned on Write to closed connection.
var ErrClosing = errors.New("use of closed network connection")

// ConnCtx is a wrapper of net.Conn using context.Context.
type ConnCtx interface {
Read(context.Context, []byte) (int, error)
Write(context.Context, []byte) (int, error)
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
Conn() net.Conn
}

type connCtx struct {
nextConn net.Conn
closed chan struct{}
closeOnce sync.Once
readMu sync.Mutex
writeMu sync.Mutex
}

var veryOld = time.Unix(0, 1)

// New creates a new ConnCtx wrapping given net.Conn.
func New(conn net.Conn) ConnCtx {
c := &connCtx{
nextConn: conn,
closed: make(chan struct{}),
}
return c
}

func (c *connCtx) Read(ctx context.Context, b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()

select {
case <-c.closed:
return 0, io.EOF
default:
}

done := make(chan struct{})
var wg sync.WaitGroup
var errSetDeadline atomic.Value
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-ctx.Done():
// context canceled
if err := c.nextConn.SetReadDeadline(veryOld); err != nil {
errSetDeadline.Store(err)
return
}
<-done
if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil {
errSetDeadline.Store(err)
}
case <-done:
}
}()

n, err := c.nextConn.Read(b)

close(done)
wg.Wait()
if e := ctx.Err(); e != nil && n == 0 {
err = e
}
if err2 := errSetDeadline.Load(); err == nil && err2 != nil {
err = err2.(error)
}
return n, err
}

func (c *connCtx) Write(ctx context.Context, b []byte) (int, error) {
c.writeMu.Lock()
defer c.writeMu.Unlock()

select {
case <-c.closed:
return 0, ErrClosing
default:
}

done := make(chan struct{})
var wg sync.WaitGroup
var errSetDeadline atomic.Value
wg.Add(1)
go func() {
select {
case <-ctx.Done():
// context canceled
if err := c.nextConn.SetWriteDeadline(veryOld); err != nil {
errSetDeadline.Store(err)
return
}
<-done
if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil {
errSetDeadline.Store(err)
}
case <-done:
}
wg.Done()
}()

n, err := c.nextConn.Write(b)

close(done)
wg.Wait()
if e := ctx.Err(); e != nil && n == 0 {
err = e
}
if err2 := errSetDeadline.Load(); err == nil && err2 != nil {
err = err2.(error)
}
return n, err
}

func (c *connCtx) Close() error {
err := c.nextConn.Close()
c.closeOnce.Do(func() {
c.writeMu.Lock()
c.readMu.Lock()
close(c.closed)
c.readMu.Unlock()
c.writeMu.Unlock()
})
return err
}

func (c *connCtx) LocalAddr() net.Addr {
return c.nextConn.LocalAddr()
}

func (c *connCtx) RemoteAddr() net.Addr {
return c.nextConn.LocalAddr()
}

func (c *connCtx) Conn() net.Conn {
return c.nextConn
}
197 changes: 197 additions & 0 deletions internal/net/connctx/connctx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package connctx

import (
"bytes"
"context"
"io"
"net"
"testing"
"time"
)

func TestRead(t *testing.T) {
ca, cb := net.Pipe()
defer func() {
_ = ca.Close()
}()

data := []byte{0x01, 0x02, 0xFF}
chErr := make(chan error)

go func() {
_, err := cb.Write(data)
chErr <- err
}()

c := New(ca)
b := make([]byte, 100)
n, err := c.Read(context.Background(), b)
if err != nil {
t.Fatal(err)
}
if n != len(data) {
t.Errorf("Wrong data length, expected %d, got %d", len(data), n)
}
if !bytes.Equal(data, b[:n]) {
t.Errorf("Wrong data, expected %v, got %v", data, b)
}

err = <-chErr
if err != nil {
t.Fatal(err)
}
}

func TestReadTImeout(t *testing.T) {
ca, _ := net.Pipe()
defer func() {
_ = ca.Close()
}()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()

c := New(ca)
b := make([]byte, 100)
n, err := c.Read(ctx, b)
if err == nil {
t.Error("Read unexpectedly successed")
}
if n != 0 {
t.Errorf("Wrong data length, expected %d, got %d", 0, n)
}
}

func TestReadCancel(t *testing.T) {
ca, _ := net.Pipe()
defer func() {
_ = ca.Close()
}()

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()

c := New(ca)
b := make([]byte, 100)
n, err := c.Read(ctx, b)
if err == nil {
t.Error("Read unexpectedly successed")
}
if n != 0 {
t.Errorf("Wrong data length, expected %d, got %d", 0, n)
}
}

func TestReadClosed(t *testing.T) {
ca, _ := net.Pipe()

c := New(ca)
_ = c.Close()

b := make([]byte, 100)
n, err := c.Read(context.Background(), b)
if err != io.EOF {
t.Errorf("Expected error '%v', got '%v'", io.EOF, err)
}
if n != 0 {
t.Errorf("Wrong data length, expected %d, got %d", 0, n)
}
}

func TestWrite(t *testing.T) {
ca, cb := net.Pipe()
defer func() {
_ = ca.Close()
}()

chErr := make(chan error)
chRead := make(chan []byte)

go func() {
b := make([]byte, 100)
n, err := cb.Read(b)
chErr <- err
chRead <- b[:n]
}()

c := New(ca)
data := []byte{0x01, 0x02, 0xFF}
n, err := c.Write(context.Background(), data)
if err != nil {
t.Fatal(err)
}
if n != len(data) {
t.Errorf("Wrong data length, expected %d, got %d", len(data), n)
}

err = <-chErr
b := <-chRead
if !bytes.Equal(data, b) {
t.Errorf("Wrong data, expected %v, got %v", data, b)
}
if err != nil {
t.Fatal(err)
}
}

func TestWriteTimeout(t *testing.T) {
ca, _ := net.Pipe()
defer func() {
_ = ca.Close()
}()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()

c := New(ca)
b := make([]byte, 100)
n, err := c.Write(ctx, b)
if err == nil {
t.Error("Write unexpectedly successed")
}
if n != 0 {
t.Errorf("Wrong data length, expected %d, got %d", 0, n)
}
}

func TestWriteCancel(t *testing.T) {
ca, _ := net.Pipe()
defer func() {
_ = ca.Close()
}()

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()

c := New(ca)
b := make([]byte, 100)
n, err := c.Write(ctx, b)
if err == nil {
t.Error("Write unexpectedly successed")
}
if n != 0 {
t.Errorf("Wrong data length, expected %d, got %d", 0, n)
}
}

func TestWriteClosed(t *testing.T) {
ca, _ := net.Pipe()

c := New(ca)
_ = c.Close()

b := make([]byte, 100)
n, err := c.Write(context.Background(), b)
if err != ErrClosing {
t.Errorf("Expected error '%v', got '%v'", ErrClosing, err)
}
if n != 0 {
t.Errorf("Wrong data length, expected %d, got %d", 0, n)
}
}

0 comments on commit 6bf60d2

Please sign in to comment.