Skip to content

Commit

Permalink
feat: fix & rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
karlem committed May 22, 2024
1 parent 11e45a1 commit 07c1455
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 97 deletions.
7 changes: 4 additions & 3 deletions core/rewards/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"code.vegaprotocol.io/vega/core/events"
"code.vegaprotocol.io/vega/core/types"
"code.vegaprotocol.io/vega/core/vesting"
"code.vegaprotocol.io/vega/libs/num"
"code.vegaprotocol.io/vega/logging"
"code.vegaprotocol.io/vega/protos/vega"
Expand Down Expand Up @@ -89,7 +90,7 @@ type Teams interface {

type Vesting interface {
AddReward(party, asset string, amount *num.Uint, lockedForEpochs uint64)
GetRewardBonusMultiplier(party string) (num.Decimal, num.Decimal)
GetSingleAndSummedRewardBonusMultipliers(party string) (vesting.MultiplierAndQuantBalance, vesting.MultiplierAndQuantBalance)
}

type ActivityStreak interface {
Expand Down Expand Up @@ -357,8 +358,8 @@ func (e *Engine) convertTakerFeesToRewardAsset(takerFees map[string]*num.Uint, f

func (e *Engine) getRewardMultiplierForParty(party string) num.Decimal {
asMultiplier := e.activityStreak.GetRewardsDistributionMultiplier(party)
_, vsMultiplier := e.vesting.GetRewardBonusMultiplier(party)
return asMultiplier.Mul(vsMultiplier)
_, summed := e.vesting.GetSingleAndSummedRewardBonusMultipliers(party)
return asMultiplier.Mul(summed.Multiplier)
}

func filterEligible(ps []*types.PartyContributionScore) []*types.PartyContributionScore {
Expand Down
17 changes: 9 additions & 8 deletions core/rewards/mocks/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

98 changes: 37 additions & 61 deletions core/vesting/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package vesting

import (
"context"
"fmt"
"sort"
"time"

Expand Down Expand Up @@ -67,7 +68,7 @@ type PartyRewards struct {
Vesting map[string]*num.Uint
}

type multiplierAndQuantBalance struct {
type MultiplierAndQuantBalance struct {
Multiplier num.Decimal
QuantumBalance num.Decimal
}
Expand All @@ -91,8 +92,7 @@ type Engine struct {
parties Parties

// cache the reward bonus multiplier and quantum balance
// across all summed party-owned reward accounts (including derived keys) by party.
summedRewardBonusMultiplierCache map[string]multiplierAndQuantBalance
rewardBonusMultiplierCache map[string]MultiplierAndQuantBalance
}

func New(
Expand All @@ -106,14 +106,14 @@ func New(
log = log.Named(namedLogger)

return &Engine{
log: log,
c: c,
asvm: asvm,
broker: broker,
assets: assets,
parties: parties,
state: map[string]*PartyRewards{},
summedRewardBonusMultiplierCache: map[string]multiplierAndQuantBalance{},
log: log,
c: c,
asvm: asvm,
broker: broker,
assets: assets,
parties: parties,
state: map[string]*PartyRewards{},
rewardBonusMultiplierCache: map[string]MultiplierAndQuantBalance{},
}
}

Expand Down Expand Up @@ -153,7 +153,6 @@ func (e *Engine) OnEpochEvent(ctx context.Context, epoch types.Epoch) {
if epoch.Action == proto.EpochAction_EPOCH_ACTION_END {
e.moveLocked()
e.distributeVested(ctx)
e.clearMultiplierCache()
e.broadcastVestingStatsUpdate(ctx, epoch.Seq)
e.broadcastSummary(ctx, epoch.Seq)
e.clearState()
Expand Down Expand Up @@ -193,69 +192,46 @@ func (e *Engine) rewardBonusMultiplier(quantumBalance num.Decimal) num.Decimal {
return multiplier
}

func (e *Engine) getSummedBonusMultiplier(keys []string) multiplierAndQuantBalance {
summed := multiplierAndQuantBalance{}

for _, key := range keys {
quantumBalanceForKey := e.c.GetAllVestingQuantumBalance(key)
if quantumBalanceForKey.IsZero() {
continue
}

summed.QuantumBalance = summed.QuantumBalance.Add(quantumBalanceForKey)
}

summed.Multiplier = e.rewardBonusMultiplier(summed.QuantumBalance)

return summed
}

// GetRewardBonusMultiplier returns the reward bonus multiplier and quantum balance sum for a party and it's derived keys.
func (e *Engine) GetRewardBonusMultiplier(party string) (num.Decimal, num.Decimal) {
owner := party

partyID, derivedKeys := e.parties.RelatedKeys(party)
if partyID != nil {
owner = partyID.String()
}

if cached, ok := e.summedRewardBonusMultiplierCache[owner]; ok {
return cached.QuantumBalance, cached.Multiplier
}

summed := e.getSummedBonusMultiplier(append(derivedKeys, owner))

e.summedRewardBonusMultiplierCache[owner] = summed

return summed.QuantumBalance, summed.Multiplier
}

// GetSingleAndSummedRewardBonusMultipliers returns a single and summed reward bonus multipliers and quantum balances for a party.
// The single multiplier is calculated based on the quantum balance of the party.
// The summed multiplier is calculated based on the quantum balance of the party and all derived keys.
// Caches the summed multiplier and quantum balance for the party.
func (e *Engine) GetSingleAndSummedRewardBonusMultipliers(party string) (
single multiplierAndQuantBalance,
summed multiplierAndQuantBalance,
) {
func (e *Engine) GetSingleAndSummedRewardBonusMultipliers(party string) (MultiplierAndQuantBalance, MultiplierAndQuantBalance) {
owner := party

partyID, derivedKeys := e.parties.RelatedKeys(party)
if partyID != nil {
owner = partyID.String()
}

single = e.getSummedBonusMultiplier([]string{party})
ownerKey := fmt.Sprintf("owner-%s", owner)

if cached, ok := e.summedRewardBonusMultiplierCache[owner]; ok {
return single, cached
}
summed, foundSummed := e.rewardBonusMultiplierCache[ownerKey]

summed = e.getSummedBonusMultiplier(append(derivedKeys, owner))
for _, key := range append(derivedKeys, owner) {
single, foundSingle := e.rewardBonusMultiplierCache[key]
if !foundSingle {
quantumBalanceForKey := e.c.GetAllVestingQuantumBalance(key)
if quantumBalanceForKey.IsZero() {
continue
}

single.QuantumBalance = quantumBalanceForKey
single.Multiplier = e.rewardBonusMultiplier(quantumBalanceForKey)
e.rewardBonusMultiplierCache[key] = single
}

e.summedRewardBonusMultiplierCache[owner] = summed
if !foundSummed {
summed.QuantumBalance = summed.QuantumBalance.Add(single.QuantumBalance)
}
}

if !foundSummed {
summed.Multiplier = e.rewardBonusMultiplier(summed.QuantumBalance)
e.rewardBonusMultiplierCache[ownerKey] = summed
}

return single, summed
return e.rewardBonusMultiplierCache[party], e.rewardBonusMultiplierCache[ownerKey]
}

func (e *Engine) getPartyRewards(party string) *PartyRewards {
Expand Down Expand Up @@ -420,7 +396,7 @@ func (e *Engine) clearState() {
}

func (e *Engine) clearMultiplierCache() {
e.summedRewardBonusMultiplierCache = map[string]multiplierAndQuantBalance{}
e.rewardBonusMultiplierCache = map[string]MultiplierAndQuantBalance{}
}

func (e *Engine) broadcastSummary(ctx context.Context, seq uint64) {
Expand Down
45 changes: 20 additions & 25 deletions core/vesting/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,6 @@ func TestDistributeWithNoDelay(t *testing.T) {
t.Run("First reward payment", func(t *testing.T) {
epochSeq += 1

v.GetRewardBonusMultiplier(party)

expectLedgerMovements(t, v)

v.broker.EXPECT().Send(gomock.Any()).Do(func(evt events.Event) {
Expand Down Expand Up @@ -1284,33 +1282,22 @@ func TestGetRewardBonusMultiplier(t *testing.T) {
}

for _, key := range append(derivedKeys, party) {
balance, multiplier := v.GetRewardBonusMultiplier(key)
require.Equal(t, num.DecimalFromInt64(1500), balance)
require.Equal(t, num.DecimalFromInt64(3), multiplier)
_, summed := v.GetSingleAndSummedRewardBonusMultipliers(key)
require.Equal(t, num.DecimalFromInt64(1500), summed.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(3), summed.Multiplier)
}

// check that we only called the GetVestingQuantumBalance once for each key
// later calls should be cached
require.Equal(t, 5, v.col.GetVestingQuantumBalanceCallCount())

v.col.ResetVestingQuantumBalanceCallCount()

single, summed := v.GetSingleAndSummedRewardBonusMultipliers(party)
require.Equal(t, num.DecimalFromInt64(500), single.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(2), single.Multiplier)
require.Equal(t, num.DecimalFromInt64(1500), summed.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(3), summed.Multiplier)

for _, key := range derivedKeys {
single, summed := v.GetSingleAndSummedRewardBonusMultipliers(key)
require.Equal(t, num.DecimalFromInt64(250), single.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(1), single.Multiplier)
for _, key := range append(derivedKeys, party) {
_, summed := v.GetSingleAndSummedRewardBonusMultipliers(key)
require.Equal(t, num.DecimalFromInt64(1500), summed.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(3), summed.Multiplier)
}

// here we check that we just called 5 more times to re-calculate single values
// summed should be cached
// all the calls above should be served from cache
require.Equal(t, 5, v.col.GetVestingQuantumBalanceCallCount())

v.broker.EXPECT().Send(gomock.Any()).AnyTimes()
Expand All @@ -1325,9 +1312,9 @@ func TestGetRewardBonusMultiplier(t *testing.T) {
v.col.ResetVestingQuantumBalanceCallCount()

for _, key := range append(derivedKeys, party) {
balance, multiplier := v.GetRewardBonusMultiplier(key)
require.Equal(t, num.DecimalFromInt64(1500), balance)
require.Equal(t, num.DecimalFromInt64(3), multiplier)
_, summed := v.GetSingleAndSummedRewardBonusMultipliers(key)
require.Equal(t, num.DecimalFromInt64(1500), summed.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(3), summed.Multiplier)
}

// now it's called 5 times again because the cache gets reset at the end of the epoch
Expand All @@ -1341,9 +1328,17 @@ func TestGetRewardBonusMultiplier(t *testing.T) {
v.col.ResetVestingQuantumBalanceCallCount()

for _, key := range append(derivedKeys, party) {
balance, multiplier := v.GetRewardBonusMultiplier(key)
require.Equal(t, num.DecimalFromInt64(1500), balance)
require.Equal(t, num.DecimalFromInt64(3), multiplier)
single, summed := v.GetSingleAndSummedRewardBonusMultipliers(key)
require.Equal(t, num.DecimalFromInt64(1500), summed.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(3), summed.Multiplier)

if key == party {
require.Equal(t, num.DecimalFromInt64(500), single.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(2), single.Multiplier)
} else {
require.Equal(t, num.DecimalFromInt64(250), single.QuantumBalance)
require.Equal(t, num.DecimalFromInt64(1), single.Multiplier)
}
}

// now it's called 5 times again because the cache gets reset at the end of the epoch
Expand Down

0 comments on commit 07c1455

Please sign in to comment.