Skip to content

Commit

Permalink
use read and write lock.
Browse files Browse the repository at this point in the history
  • Loading branch information
winlinvip committed Dec 2, 2015
1 parent 580e7ef commit 2469d5e
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 42 deletions.
23 changes: 18 additions & 5 deletions protocol/rtmp.go
Expand Up @@ -3260,9 +3260,6 @@ type RtmpStack struct {
// the input and output stream.
in io.Reader
out io.Writer
// the underlayer fd, for writev to use.
// @remark unix only, other os ignore it.
fd int64
// use bytes.Buffer to parse RTMP.
// TODO: FIXME: use bufio.Reader instead.
inb bytes.Buffer
Expand All @@ -3275,12 +3272,21 @@ type RtmpStack struct {
outChunkSize uint32
// whether the stack is closing.
closing bool

// the system fd, for writev to use.
// @remark unix only, other os ignore it.
sysfd interface{}
// once start the fast send mode(writev),
// we should never use slow again in stack.
// @remark maybe the fast still use slow, it's ok.
fastMode bool
}

func NewRtmpStack(r io.Reader, w io.Writer) *RtmpStack {
return &RtmpStack{
in: r,
out: w,
sysfd: nil,
chunks: make(map[uint32]*RtmpChunk),
inChunkSize: RtmpProtocolChunkSize,
outChunkSize: RtmpProtocolChunkSize,
Expand Down Expand Up @@ -3527,10 +3533,11 @@ func (v *RtmpStack) SendMessage(msgs ...*RtmpMessage) (err error) {
}

// use simple slow send when got one message to send.
if len(msgs) == 1 {
if len(msgs) == 1 && !v.fastMode {
return v.slowSendMessages(iovs...)
}

v.fastMode = true
return v.fastSendMessages(iovs...)
}

Expand All @@ -3545,9 +3552,15 @@ func (v *RtmpStack) slowSendMessages(iovs ...[]byte) (err error) {
}
}

if _, err = io.CopyN(v.out, &b, int64(b.Len())); err != nil {
var n int
bb := b.Bytes()
if n, err = v.out.Write(bb); err != nil {
return
}

if n != len(bb) {
panic(fmt.Sprintf("netFD.Write EAGAIN, n=%v, nb=%v", n, len(bb)))
}
return
}

Expand Down
194 changes: 157 additions & 37 deletions protocol/rtmp_unix.go
Expand Up @@ -25,56 +25,176 @@ package protocol

import (
"fmt"
"github.com/ossrs/go-oryx/core"
"io"
"net"
"os"
"reflect"
"sync"
"syscall"
"unsafe"
)

func (v *RtmpStack) fastSendMessages(iovs ...[]byte) (err error) {
// initialize the fd.
if v.fd == 0 {
var ok bool
var c *net.TCPConn
if c, ok = v.out.(*net.TCPConn); !ok {
return v.slowSendMessages(iovs...)
}
type RtmpSysFd struct {
// system fd, from (c *TCPConn).(fd *netFD).sysfd
sysfd uintptr
// the *netFD, from (c *TCPConn).fd
fd reflect.Value
// the pollDesc, from (c *TCPConn).(fd *netFD).pd
pd reflect.Value

var vfd reflect.Value
// get c which is net.TCPConn
if vfd = reflect.ValueOf(c); vfd.Kind() == reflect.Ptr {
vfd = vfd.Elem()
}
// get c.fd which is net.netFD, in net/net.go
if vfd = vfd.FieldByName("fd"); vfd.Kind() == reflect.Ptr {
vfd = vfd.Elem()
}
// get c.fd.sysfd which is int, in net/fd_unix.go
if vfd = vfd.FieldByName("sysfd"); vfd.Kind() == reflect.Ptr {
vfd = vfd.Elem()
}
// get fd value.
v.fd = vfd.Int()
}
// locker.
lock sync.Mutex

// rollback to slow send.
rtmp *RtmpStack
// whether ok to writev.
ok bool
}

func (v *RtmpSysFd) init() interface{} {
// use writev when got fd.
// @see https://github.com/winlinvip/vectorio/blob/master/vectorio.go
if v.fd > 0 {
var total int
var ok bool
var c *net.TCPConn
if c, ok = v.rtmp.out.(*net.TCPConn); !ok {
return v
}

iovecs := make([]syscall.Iovec, len(iovs))
for i, iov := range iovs {
total += len(iov)
iovecs[i] = syscall.Iovec{&iov[0], uint64(len(iov))}
}
// get c which is net.TCPConn
var fc reflect.Value
if fc = reflect.ValueOf(c); fc.Kind() == reflect.Ptr {
fc = fc.Elem()
}

// get the ptr.
v.fd = fc.FieldByName("fd")

// get c.fd which is net.netFD, in net/net.go
var ffd reflect.Value = v.fd
if ffd.Kind() == reflect.Ptr {
ffd = ffd.Elem()
}

// get the ptr.
if v.pd = ffd.FieldByName("pd"); v.pd.Kind() != reflect.Ptr {
v.pd = v.pd.Addr()
}

// get c.fd.pd which is pollDesc, in net/fd_poll_runtime.go
var fpd reflect.Value = v.pd
if fpd.Kind() == reflect.Ptr {
fpd = fpd.Elem()
}

// get c.fd.sysfd which is int, in net/fd_unix.go
var fsysfd reflect.Value
if fsysfd = ffd.FieldByName("sysfd"); fsysfd.Kind() == reflect.Ptr {
fsysfd = fsysfd.Elem()
}
v.sysfd = uintptr(fsysfd.Int())

// fast writev is ok.
v.ok = true

return v
}

// delegate v.fd.writeLock
func (v *RtmpSysFd) writeLock() (err error) {
v.lock.Lock()
return
}

// delegate v.fd.writeUnlock
func (v *RtmpSysFd) writeUnlock() {
v.lock.Unlock()
return
}

// delegate v.pd.PrepareWrite
func (v *RtmpSysFd) PrepareWrite() (err error) {
return
}

// delegate v.pd.WaitWrite
func (v *RtmpSysFd) WaitWrite() (err error) {
// error: reflect.Value.Call using value obtained using unexported field
return
}

func (v *RtmpSysFd) writev(iovs ...[]byte) (err error) {
if !v.ok {
return v.rtmp.slowSendMessages(iovs...)
}

// lock the fd.
if err = v.writeLock(); err != nil {
return
}
defer v.writeUnlock()
if err = v.PrepareWrite(); err != nil {
return
}

// prepare data.
var total int
iovecs := make([]syscall.Iovec, len(iovs))
for i, iov := range iovs {
total += len(iov)
iovecs[i] = syscall.Iovec{&iov[0], uint64(len(iov))}
}

var nn int
for {
var n int
if n, err = writev(uintptr(v.fd), iovecs); err != nil {
if n, err = writev(v.sysfd, iovecs); err != nil {
return
} else if n != total {
}
// TODO: FIXME: implements is.
if n != total {
core.Error.Println("fatal.")
panic(fmt.Sprintf("writev n=%v, total=%v", n, total))
}
return

if n > 0 {
nn += n
}
if nn == total {
break
}
if err == syscall.EAGAIN {
if err = v.WaitWrite(); err == nil {

This comment has been minimized.

Copy link
@winlinvip

winlinvip Dec 2, 2015

Author Member

这个方法是没法调用的,因为pollDesc是没有导出的struct。

continue
}
}
if err != nil {
break
}
if n == 0 {
err = io.ErrUnexpectedEOF
break
}
}

if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("writev", err)
}

return
}

func (v *RtmpStack) fastSendMessages(iovs ...[]byte) (err error) {
// initialize the fd.
if v.sysfd == nil {
fd := &RtmpSysFd{
rtmp: v,
}
v.sysfd = fd.init()
}

if v, ok := v.sysfd.(*RtmpSysFd); ok {
return v.writev(iovs...)
}

return v.slowSendMessages(iovs...)
Expand All @@ -84,10 +204,10 @@ func writev(fd uintptr, iovs []syscall.Iovec) (int, error) {
iovsPtr := uintptr(unsafe.Pointer(&iovs[0]))
iovsLen := uintptr(len(iovs))

n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, fd, iovsPtr, iovsLen)
n, _, e0 := syscall.Syscall(syscall.SYS_WRITEV, fd, iovsPtr, iovsLen)

if errno != 0 {
return 0, fmt.Errorf("writev failed, errno=%v", int64(errno))
if e0 != 0 {
return 0, syscall.Errno(e0)
}

return int(n), nil
Expand Down

0 comments on commit 2469d5e

Please sign in to comment.