Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace unsafe usage of recover() in helper functions #4913

Merged
merged 5 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions client/history/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ func (c *clientImpl) GetReplicationMessages(
for peer, req := range requestsByPeer {
peer, req := peer, req
g.Go(func() (e error) {
defer log.CapturePanic(c.logger, &e)
defer func() { log.CapturePanic(recover(), c.logger, &e) }()

requestContext, cancel := common.CreateChildContext(ctx, 0.05)
defer cancel()
Expand Down Expand Up @@ -939,7 +939,7 @@ func (c *clientImpl) CountDLQMessages(
for _, peer := range peers {
peer := peer
g.Go(func() (e error) {
defer log.CapturePanic(c.logger, &e)
defer func() { log.CapturePanic(recover(), c.logger, &e) }()

response, err := c.client.CountDLQMessages(ctx, request, append(opts, yarpc.WithShardKey(peer))...)
if err == nil {
Expand Down Expand Up @@ -1047,7 +1047,7 @@ func (c *clientImpl) NotifyFailoverMarkers(
for peer, req := range requestsByPeer {
peer, req := peer, req
g.Go(func() (e error) {
defer log.CapturePanic(c.logger, &e)
defer func() { log.CapturePanic(recover(), c.logger, &e) }()

ctx, cancel := c.createContext(ctx)
defer cancel()
Expand Down
9 changes: 5 additions & 4 deletions common/log/panic.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ import (
// If the panic value is not error then a default error is returned
// We have to use pointer is because in golang: "recover return nil if was not called directly by a deferred function."
// And we have to set the returned error otherwise our handler will return nil as error which is incorrect
// NOTE: this function MUST be called in a deferred function
func CapturePanic(logger Logger, retError *error) {
// revive:disable-next-line:defer Caller must call from a deferred function
if errPanic := recover(); errPanic != nil {
// errPanic MUST be the result from calling recover, which MUST be done in a single level deep
// deferred function. The usual way of calling this is:
// - defer func() { log.CapturePanic(recover(), logger, &err) }()
func CapturePanic(errPanic interface{}, logger Logger, retError *error) {
if errPanic != nil {
err, ok := errPanic.(error)
if !ok {
err = fmt.Errorf("panic object is not error: %#v", errPanic)
Comment on lines +34 to 41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll lightly vote for a rename, but either's clear enough I think. and there is some benefit in implying something may be panicking error objects and they're handled specially...

eh. your pick.

Suggested change
// errPanic MUST be the result from calling recover, which MUST be done in a single level deep
// deferred function. The usual way of calling this is:
// - defer func() { log.CapturePanic(recover(), logger, &err) }()
func CapturePanic(errPanic interface{}, logger Logger, retError *error) {
if errPanic != nil {
err, ok := errPanic.(error)
if !ok {
err = fmt.Errorf("panic object is not error: %#v", errPanic)
// "recovered" MUST be the result from calling recover, which MUST be done in a single level deep
// deferred function. The usual way of calling this is:
// - defer func() { log.CapturePanic(recover(), logger, &err) }()
func CapturePanic(recovered interface{}, logger Logger, retError *error) {
if recovered != nil {
err, ok := recovered.(error)
if !ok {
err = fmt.Errorf("panic object is not error: %#v", recovered)

Expand Down
46 changes: 22 additions & 24 deletions common/persistence/sql/sqlExecutionStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,9 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
ctx context.Context,
request *p.InternalGetWorkflowExecutionRequest,
) (resp *p.InternalGetWorkflowExecutionResponse, e error) {
recoverPanic := func(err *error) {
// revive:disable-next-line:defer Func is being called using defer().
if r := recover(); r != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", r, debug.Stack())
recoverPanic := func(recovered interface{}, err *error) {
if recovered != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", recovered, debug.Stack())
}
}

Expand All @@ -291,55 +290,55 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
g, ctx := errgroup.WithContext(ctx)

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
executions, e = m.getExecutions(ctx, request, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
activityInfos, e = getActivityInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
timerInfos, e = getTimerInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
childExecutionInfos, e = getChildExecutionInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
requestCancelInfos, e = getRequestCancelInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
signalInfos, e = getSignalInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
bufferedEvents, e = getBufferedEvents(
ctx, m.db, m.shardID, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
signalsRequested, e = getSignalsRequested(
ctx, m.db, m.shardID, domainID, wfID, runID)
return e
Expand Down Expand Up @@ -619,10 +618,9 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
ctx context.Context,
request *p.DeleteWorkflowExecutionRequest,
) error {
recoverPanic := func(err *error) {
// revive:disable-next-line:defer Func is being called using defer().
if r := recover(); r != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", r, debug.Stack())
recoverPanic := func(recovered interface{}, err *error) {
if recovered != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", recovered, debug.Stack())
}
}
domainID := serialization.MustParseUUID(request.DomainID)
Expand All @@ -631,7 +629,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
g, ctx := errgroup.WithContext(ctx)

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromExecutions(ctx, &sqlplugin.ExecutionsFilter{
ShardID: m.shardID,
DomainID: domainID,
Expand All @@ -642,7 +640,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromActivityInfoMaps(ctx, &sqlplugin.ActivityInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -653,7 +651,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromTimerInfoMaps(ctx, &sqlplugin.TimerInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -664,7 +662,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromChildExecutionInfoMaps(ctx, &sqlplugin.ChildExecutionInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -675,7 +673,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromRequestCancelInfoMaps(ctx, &sqlplugin.RequestCancelInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -686,7 +684,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromSignalInfoMaps(ctx, &sqlplugin.SignalInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -697,7 +695,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromBufferedEvents(ctx, &sqlplugin.BufferedEventsFilter{
ShardID: m.shardID,
DomainID: domainID,
Expand All @@ -708,7 +706,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromSignalsRequestedSets(ctx, &sqlplugin.SignalsRequestedSetsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand Down
Loading