Skip to content

Commit

Permalink
Merge pull request #3451 from vegaprotocol/hotfix/fix-fees-calculatio…
Browse files Browse the repository at this point in the history
…n-lp-shares

remove float from the core shares calculation for LP + use shopsring decimal
  • Loading branch information
jeremyletang committed May 4, 2021
1 parent 92339d9 commit 04aa1cd
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 83 deletions.
53 changes: 29 additions & 24 deletions execution/equity_shares.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,50 @@ package execution

import (
"fmt"
"math/big"

"github.com/shopspring/decimal"
)

// lp holds LiquidityProvider stake and avg values
type lp struct {
stake float64
share float64
avg float64
stake decimal.Decimal
share decimal.Decimal
avg decimal.Decimal
}

// EquityShares module controls the Equity sharing algorithm described on the spec:
// https://github.com/vegaprotocol/product/blob/02af55e048a92a204e9ee7b7ae6b4475a198c7ff/specs/0042-setting-fees-and-rewarding-lps.md#calculating-liquidity-provider-equity-like-share
type EquityShares struct {
// mvp is the MarketValueProxy
mvp float64
mvp decimal.Decimal

// lps is a map of party id to lp (LiquidityProviders)
lps map[string]*lp
}

func NewEquityShares(mvp float64) *EquityShares {
func NewEquityShares(mvp decimal.Decimal) *EquityShares {
return &EquityShares{
mvp: mvp,
lps: map[string]*lp{},
}
}

func (es *EquityShares) WithMVP(mvp float64) *EquityShares {
func (es *EquityShares) WithMVP(mvp decimal.Decimal) *EquityShares {
es.mvp = mvp
return es
}

// SetPartyStake sets LP values for a given party.
func (es *EquityShares) SetPartyStake(id string, newStake float64) {
func (es *EquityShares) SetPartyStake(id string, newStakeU64 uint64) {
newStake := decimal.NewFromBigInt(new(big.Int).SetUint64(newStakeU64), 0)
v, found := es.lps[id]
// first time we set the newStake and mvp as avg.
if !found {
if newStake > 0 {
if newStake.GreaterThan(decimal.Zero) {
// if marketValueProxy == 0
// we assume mvp will be our stake?
if es.mvp == 0 {
if es.mvp.Equal(decimal.Zero) {
es.mvp = newStake
}
es.lps[id] = &lp{stake: newStake, avg: es.mvp}
Expand All @@ -51,33 +55,34 @@ func (es *EquityShares) SetPartyStake(id string, newStake float64) {
return
}

if newStake <= 0 {
if newStake.Equal(decimal.Zero) {
// We are removing an existing stake
delete(es.lps, id)
return
}

if newStake <= v.stake {
if newStake.LessThanOrEqual(v.stake) {
v.stake = newStake
return
}

// delta will allways be > 0 at this point
delta := newStake - v.stake
delta := newStake.Sub(v.stake)
eq := es.mustEquity(id)
v.avg = ((eq * v.avg) + (delta * es.mvp)) / (eq + v.stake)
// v.avg = ((eq * v.avg) + (delta * es.mvp)) / (eq + v.stake)
v.avg = (eq.Mul(v.avg).Add(delta.Mul(es.mvp))).Div(eq.Add(v.stake))
v.stake = newStake
}

// AvgEntryValuation returns the Average Entry Valuation for a given party.
func (es *EquityShares) AvgEntryValuation(id string) float64 {
func (es *EquityShares) AvgEntryValuation(id string) decimal.Decimal {
if v, ok := es.lps[id]; ok {
return v.avg
}
return 0
return decimal.Zero
}

func (es *EquityShares) mustEquity(party string) float64 {
func (es *EquityShares) mustEquity(party string) decimal.Decimal {
eq, err := es.equity(party)
if err != nil {
panic(err)
Expand All @@ -90,20 +95,20 @@ func (es *EquityShares) mustEquity(party string) float64 {
// given a party id (i).
//
// Returns an error if the party has no stake.
func (es *EquityShares) equity(id string) (float64, error) {
func (es *EquityShares) equity(id string) (decimal.Decimal, error) {
if v, ok := es.lps[id]; ok {
return (v.stake * es.mvp) / v.avg, nil
return (v.stake.Mul(es.mvp)).Div(v.avg), nil
}

return 0, fmt.Errorf("party %s has no stake", id)
return decimal.Zero, fmt.Errorf("party %s has no stake", id)
}

// Shares returns the ratio of equity for a given party
func (es *EquityShares) Shares(undeployed map[string]struct{}) map[string]float64 {
func (es *EquityShares) Shares(undeployed map[string]struct{}) map[string]decimal.Decimal {
// Calculate the equity for each party and the totalEquity (the sum of all
// the equities)
var totalEquity float64
shares := map[string]float64{}
var totalEquity decimal.Decimal
shares := map[string]decimal.Decimal{}
for id := range es.lps {
// if the party is not one of the deployed parties,
// we just skip
Expand All @@ -118,11 +123,11 @@ func (es *EquityShares) Shares(undeployed map[string]struct{}) map[string]float6
panic(err)
}
shares[id] = eq
totalEquity += eq
totalEquity = totalEquity.Add(eq)
}

for id, eq := range shares {
eqshare := eq / totalEquity
eqshare := eq.Div(totalEquity)
shares[id] = eqshare
es.lps[id].share = eqshare
}
Expand Down
49 changes: 25 additions & 24 deletions execution/equity_shares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -22,43 +23,43 @@ func TestEquityShares(t *testing.T) {
// TestEquitySharesAverageEntryValuation is based on the spec example:
// https://github.com/vegaprotocol/product/blob/02af55e048a92a204e9ee7b7ae6b4475a198c7ff/specs/0042-setting-fees-and-rewarding-lps.md#calculating-liquidity-provider-equity-like-share
func testAverageEntryValuation(t *testing.T) {
es := execution.NewEquityShares(100)
es := execution.NewEquityShares(decimal.NewFromFloat(100.))

es.SetPartyStake("LP1", 100)
require.EqualValues(t, 100, es.AvgEntryValuation("LP1"))
es.SetPartyStake("LP1", uint64(100))
require.EqualValues(t, decimal.NewFromFloat(100.), es.AvgEntryValuation("LP1"))

es.SetPartyStake("LP1", 200)
require.EqualValues(t, 100, es.AvgEntryValuation("LP1"))
es.SetPartyStake("LP1", uint64(200))
require.True(t, decimal.NewFromFloat(100.).Equal(es.AvgEntryValuation("LP1")))

es.WithMVP(200).SetPartyStake("LP2", 200)
require.EqualValues(t, 200, es.AvgEntryValuation("LP2"))
require.EqualValues(t, 100, es.AvgEntryValuation("LP1"))
es.WithMVP(decimal.NewFromFloat(200.)).SetPartyStake("LP2", uint64(200))
require.True(t, decimal.NewFromFloat(200.).Equal(es.AvgEntryValuation("LP2")))
require.True(t, decimal.NewFromFloat(100.).Equal(es.AvgEntryValuation("LP1")))

es.WithMVP(400).SetPartyStake("LP1", 300)
require.EqualValues(t, 120, es.AvgEntryValuation("LP1"))
es.WithMVP(decimal.NewFromFloat(400.)).SetPartyStake("LP1", uint64(300))
require.True(t, decimal.NewFromFloat(120.).Equal(es.AvgEntryValuation("LP1")))

es.SetPartyStake("LP1", 1)
require.EqualValues(t, 120, es.AvgEntryValuation("LP1"))
require.EqualValues(t, 200, es.AvgEntryValuation("LP2"))
es.SetPartyStake("LP1", uint64(1))
require.True(t, decimal.NewFromFloat(120.).Equal(es.AvgEntryValuation("LP1")))
require.True(t, decimal.NewFromFloat(200.).Equal(es.AvgEntryValuation("LP2")))
}

func testShares(t *testing.T) {
var (
oneSixth = 1.0 / 6
oneThird = 1.0 / 3
oneFourth = 1.0 / 4
threeFourth = 3.0 / 4
twoThirds = 2.0 / 3
half = 1.0 / 2
oneSixth = decimal.NewFromFloat(1.0).Div(decimal.NewFromFloat(6.))
oneThird = decimal.NewFromFloat(1.0).Div(decimal.NewFromFloat(3.))
oneFourth = decimal.NewFromFloat(1.0).Div(decimal.NewFromFloat(4.))
threeFourth = decimal.NewFromFloat(3.0).Div(decimal.NewFromFloat(4.))
twoThirds = decimal.NewFromFloat(2.0).Div(decimal.NewFromFloat(3.))
half = decimal.NewFromFloat(1.0).Div(decimal.NewFromFloat(2.))
)

es := execution.NewEquityShares(100)
es := execution.NewEquityShares(decimal.NewFromFloat(100.))

// Set LP1
es.SetPartyStake("LP1", 100)
t.Run("LP1", func(t *testing.T) {
s := es.Shares(map[string]struct{}{})
assert.Equal(t, 1.0, s["LP1"])
assert.True(t, decimal.NewFromFloat(1.0).Equal(s["LP1"]))
})

// Set LP2
Expand All @@ -69,7 +70,7 @@ func testShares(t *testing.T) {

assert.Equal(t, oneThird, lp1)
assert.Equal(t, twoThirds, lp2)
assert.Equal(t, 1.0, lp1+lp2)
assert.True(t, decimal.NewFromFloat(1.0).Equal(lp1.Add(lp2)))
})

// Set LP3
Expand All @@ -82,7 +83,7 @@ func testShares(t *testing.T) {
assert.Equal(t, oneSixth, lp1)
assert.Equal(t, oneThird, lp2)
assert.Equal(t, half, lp3)
assert.Equal(t, 1.0, lp1+lp2+lp3)
assert.True(t, decimal.NewFromFloat(1.0).Equal(lp1.Add(lp2).Add(lp3)))
})

// LP2 is undeployed
Expand All @@ -97,7 +98,7 @@ func testShares(t *testing.T) {
assert.Equal(t, oneFourth, lp1)
// assert.Equal(t, oneThird, lp2)
assert.Equal(t, threeFourth, lp3)
assert.Equal(t, 1.0, lp1+lp3)
assert.True(t, decimal.NewFromFloat(1.0).Equal(lp1.Add(lp3)))
})
}

Expand Down
14 changes: 9 additions & 5 deletions execution/fees.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package execution

import (
"errors"
"math"
"math/big"
"time"

"github.com/shopspring/decimal"
)

type FeeSplitter struct {
Expand Down Expand Up @@ -51,13 +53,15 @@ func (fs *FeeSplitter) activeWindowLength(mvw time.Duration) time.Duration {

// MarketValueProxy returns the market value proxy according to the spec:
// https://github.com/vegaprotocol/product/blob/master/specs/0042-setting-fees-and-rewarding-lps.md
func (fs *FeeSplitter) MarketValueProxy(mvwl time.Duration, totalStake float64) float64 {
func (fs *FeeSplitter) MarketValueProxy(mvwl time.Duration, totalStakeU64 uint64) decimal.Decimal {
totalStake := decimal.NewFromBigInt(new(big.Int).SetUint64(totalStakeU64), 0)
// t is the distance between
awl := fs.activeWindowLength(mvwl)
if awl > 0 {
factor := mvwl.Seconds() / awl.Seconds()
tv := fs.tradeValue
return math.Max(totalStake, factor*float64(tv))
factor := decimal.NewFromFloat(mvwl.Seconds()).Div(
decimal.NewFromFloat(awl.Seconds()))
tv := decimal.NewFromBigInt(new(big.Int).SetUint64(fs.tradeValue), 0)
return decimal.Max(totalStake, factor.Mul(tv))
}
return totalStake
}
Expand Down
19 changes: 10 additions & 9 deletions execution/fees_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,40 @@ import (
"testing"
"time"

"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
)

func TestFeeSplitter(t *testing.T) {
var (
totalStake float64 = 100
timeWindowStart = time.Now()
marketValueWindowLength = 1 * time.Minute
totalStake uint64 = 100
timeWindowStart = time.Now()
marketValueWindowLength = 1 * time.Minute
)

tests := []struct {
currentTime time.Time
tradedValue uint64
expectedValueProxy float64
expectedValueProxy decimal.Decimal
}{
{
currentTime: timeWindowStart,
expectedValueProxy: 100,
expectedValueProxy: decimal.NewFromFloat(100.),
},
{
tradedValue: 10,
currentTime: timeWindowStart.Add(10 * time.Second),
expectedValueProxy: 100,
expectedValueProxy: decimal.NewFromFloat(100.),
},
{
tradedValue: 100,
currentTime: timeWindowStart.Add(30 * time.Second),
expectedValueProxy: 200,
expectedValueProxy: decimal.NewFromFloat(200.),
},
{
tradedValue: 300,
currentTime: timeWindowStart.Add(3 * marketValueWindowLength),
expectedValueProxy: 300,
expectedValueProxy: decimal.NewFromFloat(300.),
},
}

Expand All @@ -50,7 +51,7 @@ func TestFeeSplitter(t *testing.T) {
fs.AddTradeValue(test.tradedValue)

got := fs.MarketValueProxy(marketValueWindowLength, totalStake)
require.Equal(t, test.expectedValueProxy, got)
require.True(t, test.expectedValueProxy.Equal(got))
})
}
}
4 changes: 2 additions & 2 deletions execution/liquidity_provision.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (m *Market) finalizeLiquidityProvisionAmendmentAuction(
// now we can update the liquidity fee to be taken
m.updateLiquidityFee(ctx)
// now we can setup our party stake to calculate equities
m.equityShares.SetPartyStake(party, float64(sub.CommitmentAmount))
m.equityShares.SetPartyStake(party, sub.CommitmentAmount)
// force update of shares so they are updated for all
_ = m.equityShares.Shares(m.liquidity.GetInactiveParties())

Expand Down Expand Up @@ -272,7 +272,7 @@ func (m *Market) finalizeLiquidityProvisionAmendmentContinuous(
// now we can update the liquidity fee to be taken
m.updateLiquidityFee(ctx)
// now we can setup our party stake to calculate equities
m.equityShares.SetPartyStake(party, float64(sub.CommitmentAmount))
m.equityShares.SetPartyStake(party, sub.CommitmentAmount)
// force update of shares so they are updated for all
_ = m.equityShares.Shares(m.liquidity.GetInactiveParties())

Expand Down

0 comments on commit 04aa1cd

Please sign in to comment.