Permalink
Browse files

Simplified PipeConns - now they properly handle the case when reader …

…side is closed
  • Loading branch information...
valyala committed Jun 10, 2016
1 parent 7c6a327 commit 80af8b2b977d7c365fd22bf4573e5d2673c5870d
Showing with 97 additions and 133 deletions.
  1. +64 −123 fasthttputil/pipeconns.go
  2. +4 −4 fasthttputil/pipeconns_test.go
  3. +29 −6 stream_test.go
View
@@ -10,14 +10,16 @@ import (
// NewPipeConns returns new bi-directonal connection pipe.
func NewPipeConns() *PipeConns {
ch1 := acquirePipeChan()
ch2 := acquirePipeChan()
pc := &PipeConns{}
pc.c1.r = ch1
pc.c1.w = ch2
pc.c2.r = ch2
pc.c2.w = ch1
ch1 := make(chan *byteBuffer, 4)
ch2 := make(chan *byteBuffer, 4)
pc := &PipeConns{
stopCh: make(chan struct{}),
}
pc.c1.rCh = ch1
pc.c1.wCh = ch2
pc.c2.rCh = ch2
pc.c2.wCh = ch1
pc.c1.pc = pc
pc.c2.pc = pc
return pc
@@ -35,8 +37,10 @@ func NewPipeConns() *PipeConns {
// * It buffers Write calls, so there is no need to have concurrent goroutine
// calling Read in order to unblock each Write call.
type PipeConns struct {
c1 pipeConn
c2 pipeConn
c1 pipeConn
c2 pipeConn
stopCh chan struct{}
stopChLock sync.Mutex
}
// Conn1 returns the first end of bi-directional pipe.
@@ -55,46 +59,49 @@ func (pc *PipeConns) Conn2() net.Conn {
return &pc.c2
}
func (pc *PipeConns) release() {
pc.c1.wlock.Lock()
pc.c2.wlock.Lock()
mustRelease := pc.c1.wclosed && pc.c2.wclosed
pc.c1.wlock.Unlock()
pc.c2.wlock.Unlock()
if mustRelease {
pc.c1.release()
pc.c2.release()
// Close closes pipe connections.
func (pc *PipeConns) Close() error {
pc.stopChLock.Lock()
select {
case <-pc.stopCh:
default:
close(pc.stopCh)
}
pc.stopChLock.Unlock()
return nil
}
type pipeConn struct {
r *pipeChan
w *pipeChan
b *byteBuffer
bb []byte
rlock sync.Mutex
rclosed bool
wlock sync.Mutex
wclosed bool
pc *PipeConns
rCh chan *byteBuffer
wCh chan *byteBuffer
pc *PipeConns
}
func (c *pipeConn) Write(p []byte) (int, error) {
b := acquireByteBuffer()
b.b = append(b.b[:0], p...)
c.wlock.Lock()
if c.wclosed {
c.wlock.Unlock()
select {
case <-c.pc.stopCh:
releaseByteBuffer(b)
return 0, errConnectionClosed
default:
}
select {
case c.wCh <- b:
default:
select {
case c.wCh <- b:
case <-c.pc.stopCh:
releaseByteBuffer(b)
return 0, errConnectionClosed
}
}
c.w.ch <- b
c.wlock.Unlock()
return len(p), nil
}
@@ -120,39 +127,35 @@ func (c *pipeConn) Read(p []byte) (int, error) {
func (c *pipeConn) read(p []byte, mayBlock bool) (int, error) {
if len(c.bb) == 0 {
c.rlock.Lock()
if err := c.readNextByteBuffer(mayBlock); err != nil {
return 0, err
}
}
n := copy(p, c.bb)
c.bb = c.bb[n:]
releaseByteBuffer(c.b)
c.b = nil
return n, nil
}
if c.rclosed {
c.rlock.Unlock()
return 0, io.EOF
}
func (c *pipeConn) readNextByteBuffer(mayBlock bool) error {
releaseByteBuffer(c.b)
c.b = nil
if mayBlock {
c.b = <-c.r.ch
} else {
select {
case c.b = <-c.r.ch:
default:
c.rlock.Unlock()
return 0, errWouldBlock
}
select {
case c.b = <-c.rCh:
default:
if !mayBlock {
return errWouldBlock
}
if c.b == nil {
c.rclosed = true
c.rlock.Unlock()
return 0, io.EOF
select {
case c.b = <-c.rCh:
case <-c.pc.stopCh:
return io.EOF
}
c.bb = c.b.b
c.rlock.Unlock()
}
n := copy(p, c.bb)
c.bb = c.bb[n:]
return n, nil
c.bb = c.b.b
return nil
}
var (
@@ -162,42 +165,7 @@ var (
)
func (c *pipeConn) Close() error {
c.wlock.Lock()
if c.wclosed {
c.wlock.Unlock()
return errConnectionClosed
}
c.wclosed = true
c.w.ch <- nil
c.wlock.Unlock()
c.pc.release()
return nil
}
func (c *pipeConn) release() {
c.rlock.Lock()
releaseByteBuffer(c.b)
c.b = nil
if !c.rclosed {
c.rclosed = true
for b := range c.r.ch {
releaseByteBuffer(b)
if b == nil {
break
}
}
}
if c.r != nil {
releasePipeChan(c.r)
c.r = nil
c.w = nil
}
c.rlock.Unlock()
return c.pc.Close()
}
func (c *pipeConn) LocalAddr() net.Addr {
@@ -251,30 +219,3 @@ var byteBufferPool = &sync.Pool{
}
},
}
func acquirePipeChan() *pipeChan {
ch := pipeChanPool.Get().(*pipeChan)
if len(ch.ch) > 0 {
panic("BUG: non-empty pipeChan acquired")
}
return ch
}
func releasePipeChan(ch *pipeChan) {
if len(ch.ch) > 0 {
panic("BUG: non-empty pipeChan released")
}
pipeChanPool.Put(ch)
}
var pipeChanPool = &sync.Pool{
New: func() interface{} {
return &pipeChan{
ch: make(chan *byteBuffer, 4),
}
},
}
type pipeChan struct {
ch chan *byteBuffer
}
@@ -208,11 +208,11 @@ func testPipeConnsClose(t *testing.T, c1, c2 net.Conn) {
// attempt closing already closed conns
for i := 0; i < 10; i++ {
if err := c1.Close(); err == nil {
t.Fatalf("expecting error")
if err := c1.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if err := c2.Close(); err == nil {
t.Fatalf("expecting error")
if err := c2.Close(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
}
View
@@ -47,21 +47,26 @@ func TestStreamReaderClose(t *testing.T) {
return
}
fmt.Fprintf(w, "the second line must fail")
data := createFixedBody(4000)
for i := 0; i < 100; i++ {
w.Write(data)
}
if err := w.Flush(); err == nil {
ch <- fmt.Errorf("expecting error on the second flush")
}
ch <- nil
})
result := firstLine + "the"
buf := make([]byte, len(result))
buf := make([]byte, len(firstLine))
n, err := io.ReadFull(r, buf)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n != len(buf) {
t.Fatalf("unexpected number of bytes read: %d. Expecting %d", n, len(buf))
}
if string(buf) != result {
t.Fatalf("unexpected result: %q. Expecting %q", buf, result)
if string(buf) != firstLine {
t.Fatalf("unexpected result: %q. Expecting %q", buf, firstLine)
}
if err := r.Close(); err != nil {
@@ -74,6 +79,24 @@ func TestStreamReaderClose(t *testing.T) {
t.Fatalf("error returned from stream reader: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
t.Fatalf("timeout when waiting for stream reader")
}
// read trailing data
go func() {
if _, err := ioutil.ReadAll(r); err != nil {
ch <- fmt.Errorf("unexpected error when reading trailing data: %s", err)
return
}
ch <- nil
}()
select {
case err := <-ch:
if err != nil {
t.Fatalf("error returned when reading tail data: %s", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout when reading tail data")
}
}

0 comments on commit 80af8b2

Please sign in to comment.