Skip to content

Commit

Permalink
Refactor agent to send writes from only 1 goroutine (fixes #112), als…
Browse files Browse the repository at this point in the history
…o make the logic simpler
  • Loading branch information
felipejfc committed May 4, 2020
1 parent 3d3b62a commit eee2f1c
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 207 deletions.
165 changes: 88 additions & 77 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ const handlerType = "handler"
type (
// Agent corresponds to a user and is used for storing raw Conn information
Agent struct {
Session *session.Session // session
appDieChan chan bool // app die channel
chDie chan struct{} // wait for close
chSend chan pendingMessage // push message queue
chStopHeartbeat chan struct{} // stop heartbeats
chStopWrite chan struct{} // stop writing messages
Session *session.Session // session
appDieChan chan bool // app die channel
chDie chan struct{} // wait for close
chSend chan pendingWrite // push message queue
chStopHeartbeat chan struct{} // stop heartbeats
chStopWrite chan struct{} // stop writing messages
closeMutex sync.Mutex
conn net.Conn // low-level conn fd
decoder codec.PacketDecoder // binary decoder
Expand All @@ -88,6 +88,12 @@ type (
payload interface{} // payload
err bool // if its an error message
}

pendingWrite struct {
ctx context.Context
data []byte
err error
}
)

// NewAgent create new agent instance
Expand All @@ -112,7 +118,7 @@ func NewAgent(
a := &Agent{
appDieChan: dieChan,
chDie: make(chan struct{}),
chSend: make(chan pendingMessage, messagesBufferSize),
chSend: make(chan pendingWrite, messagesBufferSize),
chStopHeartbeat: make(chan struct{}),
chStopWrite: make(chan struct{}),
messagesBufferSize: messagesBufferSize,
Expand All @@ -134,14 +140,70 @@ func NewAgent(
return a
}

func (a *Agent) send(m pendingMessage) (err error) {
func (a *Agent) getMessageFromPendingMessage(pm pendingMessage) (*message.Message, error) {
payload, err := util.SerializeOrRaw(a.serializer, pm.payload)
if err != nil {
payload, err = util.GetErrorPayload(a.serializer, err)
if err != nil {
return nil, err
}
}

// construct message and encode
m := &message.Message{
Type: pm.typ,
Data: payload,
Route: pm.route,
ID: pm.mid,
Err: pm.err,
}

return m, nil
}

func (a *Agent) packetEncodeMessage(m *message.Message) ([]byte, error) {
em, err := a.messageEncoder.Encode(m)
if err != nil {
return nil, err
}

// packet encode
p, err := a.encoder.Encode(packet.Data, em)
if err != nil {
return nil, err
}
return p, nil
}

func (a *Agent) send(pendingMsg pendingMessage) (err error) {
defer func() {
if e := recover(); e != nil {
err = errors.NewError(constants.ErrBrokenPipe, errors.ErrClientClosedRequest)
}
}()
a.reportChannelSize()
a.chSend <- m

m, err := a.getMessageFromPendingMessage(pendingMsg)
if err != nil {
return err
}

// packet encode
p, err := a.packetEncodeMessage(m)
if err != nil {
return err
}

pWrite := pendingWrite{
ctx: pendingMsg.ctx,
data: p,
}

if pendingMsg.err {
pWrite.err = util.GetErrorFromPayload(a.serializer, m.Data)
}

a.chSend <- pWrite
return
}

Expand Down Expand Up @@ -170,17 +232,11 @@ func (a *Agent) ResponseMID(ctx context.Context, mid uint, v interface{}, isErro
err = isError[0]
}
if a.GetStatus() == constants.StatusClosed {
err := errors.NewError(constants.ErrBrokenPipe, errors.ErrClientClosedRequest)
tracing.FinishSpan(ctx, err)
metrics.ReportTimingFromCtx(ctx, a.metricsReporters, handlerType, err)
return err
return errors.NewError(constants.ErrBrokenPipe, errors.ErrClientClosedRequest)
}

if mid <= 0 {
err := constants.ErrSessionOnNotify
tracing.FinishSpan(ctx, err)
metrics.ReportTimingFromCtx(ctx, a.metricsReporters, handlerType, err)
return err
return constants.ErrSessionOnNotify
}

switch d := v.(type) {
Expand Down Expand Up @@ -309,9 +365,7 @@ func (a *Agent) heartbeat() {
logger.Log.Debugf("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&a.lastAt), deadline)
return
}
if _, err := a.conn.Write(hbd); err != nil {
return
}
a.chSend <- pendingWrite{data: hbd}
case <-a.chDie:
return
case <-a.chStopHeartbeat:
Expand Down Expand Up @@ -351,67 +405,17 @@ func (a *Agent) write() {

for {
select {
case data := <-a.chSend:
payload, err := util.SerializeOrRaw(a.serializer, data.payload)
if err != nil {
logger.Log.Errorf("Failed to serialize response: %s", err.Error())
payload, err = util.GetErrorPayload(a.serializer, err)
if err != nil {
tracing.FinishSpan(data.ctx, err)
if data.typ == message.Response {
metrics.ReportTimingFromCtx(data.ctx, a.metricsReporters, handlerType, err)
}
logger.Log.Error("cannot serialize message and respond to the client ", err.Error())
break
}
}

// construct message and encode
m := &message.Message{
Type: data.typ,
Data: payload,
Route: data.route,
ID: data.mid,
Err: data.err,
}
em, err := a.messageEncoder.Encode(m)
if err != nil {
tracing.FinishSpan(data.ctx, err)
if data.typ == message.Response {
metrics.ReportTimingFromCtx(data.ctx, a.metricsReporters, handlerType, err)
}
logger.Log.Errorf("Failed to encode message: %s", err.Error())
break
}

// packet encode
p, err := a.encoder.Encode(packet.Data, em)
if err != nil {
tracing.FinishSpan(data.ctx, err)
if data.typ == message.Response {
metrics.ReportTimingFromCtx(data.ctx, a.metricsReporters, handlerType, err)
}
logger.Log.Errorf("Failed to encode packet: %s", err.Error())
break
}
case pWrite := <-a.chSend:
// close agent if low-level Conn broken
if _, err := a.conn.Write(p); err != nil {
tracing.FinishSpan(data.ctx, err)
if data.typ == message.Response {
metrics.ReportTimingFromCtx(data.ctx, a.metricsReporters, handlerType, err)
}
logger.Log.Errorf("Failed to write response: %s", err.Error())
if _, err := a.conn.Write(pWrite.data); err != nil {
tracing.FinishSpan(pWrite.ctx, err)
metrics.ReportTimingFromCtx(pWrite.ctx, a.metricsReporters, handlerType, err)
logger.Log.Errorf("Failed to write in conn: %s", err.Error())
return
}
var e error
tracing.FinishSpan(data.ctx, e)
if data.typ == message.Response {
var rErr error
if m.Err {
rErr = util.GetErrorFromPayload(a.serializer, payload)
}
metrics.ReportTimingFromCtx(data.ctx, a.metricsReporters, handlerType, rErr)
}
tracing.FinishSpan(pWrite.ctx, e)
metrics.ReportTimingFromCtx(pWrite.ctx, a.metricsReporters, handlerType, pWrite.err)
case <-a.chStopWrite:
return
}
Expand All @@ -425,6 +429,13 @@ func (a *Agent) SendRequest(ctx context.Context, serverID, route string, v inter

// AnswerWithError answers with an error
func (a *Agent) AnswerWithError(ctx context.Context, mid uint, err error) {
var e error
defer func() {
if e != nil {
tracing.FinishSpan(ctx, e)
metrics.ReportTimingFromCtx(ctx, a.metricsReporters, handlerType, e)
}
}()
if ctx != nil && err != nil {
s := opentracing.SpanFromContext(ctx)
if s != nil {
Expand Down
Loading

0 comments on commit eee2f1c

Please sign in to comment.