Skip to content

Commit

Permalink
Fix panic case for GetTotalRewards and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
jununifi committed Jul 9, 2024
1 parent 67b5957 commit 84ec203
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 40 deletions.
39 changes: 8 additions & 31 deletions x/liquiditypool/keeper/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,6 @@ func (k Keeper) NewPositionIntervalAccumulation(ctx context.Context, accumName,
return k.SetAccumulator(ctx, accumulator)
}

func (k Keeper) AddToPosition(ctx context.Context, accumName, name string, newShares math.LegacyDec) error {
accumulator, err := k.GetAccumulator(ctx, accumName)
if err != nil {
return err
}
return k.AddToPositionIntervalAccumulation(ctx, accumName, name, newShares, accumulator.AccumValue)
}

func (k Keeper) AddToPositionIntervalAccumulation(ctx context.Context, accumName, name string, newShares math.LegacyDec, intervalAccumulationPerShare sdk.DecCoins) error {
if !newShares.IsPositive() {
return errors.New("Adding non-positive number of shares to position")
Expand Down Expand Up @@ -195,14 +187,6 @@ func (k Keeper) AddToPositionIntervalAccumulation(ctx context.Context, accumName
return k.SetAccumulator(ctx, accumulator)
}

func (k Keeper) RemoveFromPosition(ctx context.Context, accumName, name string, numSharesToRemove math.LegacyDec) error {
accumulator, err := k.GetAccumulator(ctx, accumName)
if err != nil {
return err
}
return k.RemoveFromPositionIntervalAccumulation(ctx, accumName, name, numSharesToRemove, accumulator.AccumValue)
}

func (k Keeper) RemoveFromPositionIntervalAccumulation(ctx context.Context, accumName, name string, numSharesToRemove math.LegacyDec, intervalAccumulationPerShare sdk.DecCoins) error {
if !numSharesToRemove.IsPositive() {
return fmt.Errorf("Removing non-positive shares (%s)", numSharesToRemove)
Expand Down Expand Up @@ -342,24 +326,17 @@ func (k Keeper) ClaimRewards(ctx context.Context, accumName, positionName string
return truncatedRewardsTotal, dust, nil
}

func (k Keeper) AddToUnclaimedRewards(ctx context.Context, accumName, positionName string, rewardsToAddTotal sdk.DecCoins) error {
position, err := k.GetAccumulatorPosition(ctx, accumName, positionName)
if err != nil {
return err
}

if rewardsToAddTotal.IsAnyNegative() {
return types.ErrNegRewardAddition
}

k.SetAccumulatorPosition(ctx, accumName, position.AccumValuePerShare, positionName, position.NumShares, position.UnclaimedRewardsTotal.Add(rewardsToAddTotal...))

return nil
}

func GetTotalRewards(accumulator types.AccumulatorObject, position types.AccumulatorPosition) sdk.DecCoins {
totalRewards := position.UnclaimedRewardsTotal

if !position.NumShares.IsPositive() {
return sdk.DecCoins{}
}
for _, coin := range position.AccumValuePerShare {
if accumulator.AccumValue.AmountOf(coin.Denom).LT(coin.Amount) {
return sdk.DecCoins{}
}
}
accumRewards := accumulator.AccumValue.Sub(position.AccumValuePerShare).MulDec(position.NumShares)
totalRewards = totalRewards.Add(accumRewards...)

Expand Down
109 changes: 100 additions & 9 deletions x/liquiditypool/keeper/accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,12 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/require"
keepertest "github.com/sunriselayer/sunrise/testutil/keeper"
"github.com/sunriselayer/sunrise/x/liquiditypool/keeper"
"github.com/sunriselayer/sunrise/x/liquiditypool/types"
)

// TODO: add test for AddToPosition
// TODO: add test for AddToPositionIntervalAccumulation
// TODO: add test for RemoveFromPosition
// TODO: add test for RemoveFromPositionIntervalAccumulation
// TODO: add test for UpdatePositionIntervalAccumulation
// TODO: add test for SetPositionIntervalAccumulation
// TODO: add test for DeletePosition
// TODO: add test for deletePosition
// TODO: add test for ClaimRewards
// TODO: add test for AddToUnclaimedRewards
// TODO: add test for GetTotalRewards

func TestAccumulatorStore(t *testing.T) {
k, _, ctx := keepertest.LiquiditypoolKeeper(t)
Expand Down Expand Up @@ -143,3 +136,101 @@ func TestNewPositionIntervalAccumulation(t *testing.T) {
require.Equal(t, position.AccumValuePerShare.String(), "1.000000000000000000denom")
require.Equal(t, position.UnclaimedRewardsTotal.String(), "")
}

func TestAddToPositionIntervalAccumulation(t *testing.T) {
k, _, ctx := keepertest.LiquiditypoolKeeper(t)
// when new shares is negative
accmulatorValuePerShare := sdk.NewDecCoins(sdk.NewDecCoin("denom", math.NewInt(1)))
err := k.AddToPositionIntervalAccumulation(ctx, "accumulator", "index", math.LegacyOneDec().Neg(), accmulatorValuePerShare)
require.Error(t, err)

// when position does not exist
err = k.AddToPositionIntervalAccumulation(ctx, "accumulator", "index", math.LegacyOneDec(), accmulatorValuePerShare)
require.Error(t, err)

// when accumulator and position exists
err = k.InitAccumulator(ctx, "accumulator")
require.NoError(t, err)

accumulator, err := k.GetAccumulator(ctx, "accumulator")
require.NoError(t, err)
accumulator.AccumValue = accumulator.AccumValue.Add(accmulatorValuePerShare...).Add(accmulatorValuePerShare...)
err = k.SetAccumulator(ctx, accumulator)
require.NoError(t, err)

err = k.NewPositionIntervalAccumulation(ctx, "accumulator", "index", math.LegacyOneDec(), accmulatorValuePerShare)
require.NoError(t, err)
err = k.AddToPositionIntervalAccumulation(ctx, "accumulator", "index", math.LegacyOneDec(), accmulatorValuePerShare)
require.NoError(t, err)

// check accumulator change
accumulator, err = k.GetAccumulator(ctx, "accumulator")
require.NoError(t, err)
require.Equal(t, accumulator.Name, "accumulator")
require.Equal(t, accumulator.AccumValue.String(), "2.000000000000000000denom")
require.Equal(t, accumulator.TotalShares.String(), "2.000000000000000000")

// check accumulator position change
position, err := k.GetAccumulatorPosition(ctx, "accumulator", "index")
require.NoError(t, err)
require.Equal(t, position.Name, "accumulator")
require.Equal(t, position.Index, "index")
require.Equal(t, position.NumShares.String(), "2.000000000000000000")
require.Equal(t, position.AccumValuePerShare.String(), "1.000000000000000000denom")
require.Equal(t, position.UnclaimedRewardsTotal.String(), "1.000000000000000000denom")
}

func TestGetTotalRewards(t *testing.T) {
// When accumulator value is lower than position value
oneDecCoins := sdk.NewDecCoins(sdk.NewDecCoin("denom", math.NewInt(1)))
twoDecCoins := sdk.NewDecCoins(sdk.NewDecCoin("denom", math.NewInt(2)))
emptyDecCoins := sdk.NewDecCoins()
rewards := keeper.GetTotalRewards(types.AccumulatorObject{
AccumValue: oneDecCoins,
}, types.AccumulatorPosition{
AccumValuePerShare: twoDecCoins,
NumShares: math.LegacyOneDec(),
UnclaimedRewardsTotal: emptyDecCoins,
})
require.Equal(t, rewards.String(), "")

// When accumulator value is equal to position value
rewards = keeper.GetTotalRewards(types.AccumulatorObject{
AccumValue: oneDecCoins,
}, types.AccumulatorPosition{
AccumValuePerShare: oneDecCoins,
NumShares: math.LegacyOneDec(),
UnclaimedRewardsTotal: emptyDecCoins,
})
require.Equal(t, rewards.String(), "")

// When accumulator value is greater than position value
rewards = keeper.GetTotalRewards(types.AccumulatorObject{
AccumValue: twoDecCoins,
}, types.AccumulatorPosition{
AccumValuePerShare: oneDecCoins,
NumShares: math.LegacyOneDec(),
UnclaimedRewardsTotal: emptyDecCoins,
})
require.Equal(t, rewards.String(), "1.000000000000000000denom")

// When position numShares is zero
rewards = keeper.GetTotalRewards(types.AccumulatorObject{
AccumValue: twoDecCoins,
}, types.AccumulatorPosition{
AccumValuePerShare: oneDecCoins,
NumShares: math.LegacyZeroDec(),
UnclaimedRewardsTotal: emptyDecCoins,
})
require.Equal(t, rewards.String(), "")

// When position numShares is negative
rewards = keeper.GetTotalRewards(types.AccumulatorObject{
AccumValue: twoDecCoins,
}, types.AccumulatorPosition{
AccumValuePerShare: oneDecCoins,
NumShares: math.LegacyZeroDec().Neg(),
UnclaimedRewardsTotal: emptyDecCoins,
})
require.Equal(t, rewards.String(), "")
}

0 comments on commit 84ec203

Please sign in to comment.