Skip to content

Commit

Permalink
Fixed incorrect validation of prices
Browse files Browse the repository at this point in the history
  • Loading branch information
bsrinivas8687 committed Aug 15, 2023
1 parent f3e17df commit f302367
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 185 deletions.
80 changes: 38 additions & 42 deletions x/node/keeper/abci.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,58 +42,54 @@ func (k *Keeper) EndBlock(ctx sdk.Context) []abcitypes.ValidatorUpdate {
k.IterateNodes(ctx, func(_ int, item types.Node) bool {
k.Logger(ctx).Info("Updating prices for node", "address", item.Address)

if item.GigabytePrices != nil {
if maxGigabytePricesModified {
for _, coin := range maxGigabytePrices {
amount := item.GigabytePrices.AmountOf(coin.Denom)
if amount.GT(coin.Amount) {
item.GigabytePrices = item.GigabytePrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
if maxGigabytePricesModified {
for _, coin := range maxGigabytePrices {
amount := item.GigabytePrices.AmountOf(coin.Denom)
if amount.GT(coin.Amount) {
item.GigabytePrices = item.GigabytePrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
}
}

if minGigabytePricesModified {
for _, coin := range minGigabytePrices {
amount := item.GigabytePrices.AmountOf(coin.Denom)
if amount.LT(coin.Amount) {
item.GigabytePrices = item.GigabytePrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
if minGigabytePricesModified {
for _, coin := range minGigabytePrices {
amount := item.GigabytePrices.AmountOf(coin.Denom)
if amount.LT(coin.Amount) {
item.GigabytePrices = item.GigabytePrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
}
}

if item.HourlyPrices != nil {
if maxHourlyPricesModified {
for _, coin := range maxHourlyPrices {
amount := item.HourlyPrices.AmountOf(coin.Denom)
if amount.GT(coin.Amount) {
item.HourlyPrices = item.HourlyPrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
if maxHourlyPricesModified {
for _, coin := range maxHourlyPrices {
amount := item.HourlyPrices.AmountOf(coin.Denom)
if amount.GT(coin.Amount) {
item.HourlyPrices = item.HourlyPrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
}
}

if minHourlyPricesModified {
for _, coin := range minHourlyPrices {
amount := item.HourlyPrices.AmountOf(coin.Denom)
if amount.LT(coin.Amount) {
item.HourlyPrices = item.HourlyPrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
if minHourlyPricesModified {
for _, coin := range minHourlyPrices {
amount := item.HourlyPrices.AmountOf(coin.Denom)
if amount.LT(coin.Amount) {
item.HourlyPrices = item.HourlyPrices.Sub(
sdk.NewCoins(
sdk.NewCoin(coin.Denom, amount),
),
).Add(coin)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion x/node/keeper/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (k Migrator) migrateNodes(ctx sdk.Context) error {
node := types.Node{
Address: value.Address,
GigabytePrices: value.Price,
HourlyPrices: nil,
HourlyPrices: value.Price,
RemoteURL: value.RemoteURL,
InactiveAt: time.Time{},
Status: value.Status,
Expand Down
24 changes: 11 additions & 13 deletions x/node/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@ func (k *msgServer) MsgRegister(c context.Context, msg *types.MsgRegisterRequest
ctx := sdk.UnwrapSDKContext(c)

// Check if the provided GigabytePrices are valid, if not, return an error.
if msg.GigabytePrices != nil {
if !k.IsValidGigabytePrices(ctx, msg.GigabytePrices) {
return nil, types.NewErrorInvalidPrices(msg.GigabytePrices)
}
if !k.IsValidGigabytePrices(ctx, msg.GigabytePrices) {
return nil, types.NewErrorInvalidPrices(msg.GigabytePrices)
}

// Check if the provided HourlyPrices are valid, if not, return an error.
if msg.HourlyPrices != nil {
if !k.IsValidHourlyPrices(ctx, msg.HourlyPrices) {
return nil, types.NewErrorInvalidPrices(msg.HourlyPrices)
}
if !k.IsValidHourlyPrices(ctx, msg.HourlyPrices) {
return nil, types.NewErrorInvalidPrices(msg.HourlyPrices)
}

// Convert the `msg.From` address (in Bech32 format) to an `sdk.AccAddress`.
Expand Down Expand Up @@ -121,11 +117,13 @@ func (k *msgServer) MsgUpdateDetails(c context.Context, msg *types.MsgUpdateDeta
return nil, types.NewErrorNodeNotFound(nodeAddr)
}

// Update the node's GigabytePrices and HourlyPrices with the provided values.
node.GigabytePrices = msg.GigabytePrices
node.HourlyPrices = msg.HourlyPrices

// If a RemoteURL is provided, update the node's RemoteURL as well.
// Update the node's GigabytePrices, HourlyPrices, and RemoteURL with the provided values.
if msg.GigabytePrices != nil {
node.GigabytePrices = msg.GigabytePrices
}
if msg.HourlyPrices != nil {
node.HourlyPrices = msg.HourlyPrices
}
if msg.RemoteURL != "" {
node.RemoteURL = msg.RemoteURL
}
Expand Down
14 changes: 4 additions & 10 deletions x/node/keeper/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,21 @@ func (k *Keeper) IsValidHourlyPrices(ctx sdk.Context, prices sdk.Coins) bool {
}

func (k *Keeper) IsValidSubscriptionGigabytes(ctx sdk.Context, gigabytes int64) bool {
maxGigabytes := k.MaxSubscriptionGigabytes(ctx)
if maxGigabytes != 0 && gigabytes > maxGigabytes {
if gigabytes < k.MinSubscriptionGigabytes(ctx) {
return false
}

minGigabytes := k.MinSubscriptionGigabytes(ctx)
if minGigabytes != 0 && gigabytes < minGigabytes {
if gigabytes > k.MaxSubscriptionGigabytes(ctx) {
return false
}

return true
}

func (k *Keeper) IsValidSubscriptionHours(ctx sdk.Context, hours int64) bool {
maxHours := k.MaxSubscriptionHours(ctx)
if maxHours != 0 && hours > maxHours {
if hours < k.MinSubscriptionHours(ctx) {
return false
}

minHours := k.MinSubscriptionHours(ctx)
if minHours != 0 && hours < minHours {
if hours > k.MaxSubscriptionHours(ctx) {
return false
}

Expand Down
51 changes: 27 additions & 24 deletions x/node/types/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,29 @@ func (m *MsgRegisterRequest) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(m.From); err != nil {
return errors.Wrap(ErrorInvalidMessage, err.Error())
}
if m.GigabytePrices != nil {
if m.GigabytePrices.Len() == 0 {
return errors.Wrap(ErrorInvalidMessage, "gigabyte_prices length cannot be zero")
}
if m.GigabytePrices.IsAnyNil() {
return errors.Wrap(ErrorInvalidMessage, "gigabyte_prices cannot contain nil")
}
if !m.GigabytePrices.IsValid() {
return errors.Wrap(ErrorInvalidMessage, "gigabyte_prices must be valid")
}
if m.GigabytePrices == nil {
return errors.Wrap(ErrorInvalidMessage, "gigabyte_prices cannot be nil")
}
if m.HourlyPrices != nil {
if m.HourlyPrices.Len() == 0 {
return errors.Wrap(ErrorInvalidMessage, "hourly_prices length cannot be zero")
}
if m.HourlyPrices.IsAnyNil() {
return errors.Wrap(ErrorInvalidMessage, "hourly_prices cannot contain nil")
}
if !m.HourlyPrices.IsValid() {
return errors.Wrap(ErrorInvalidMessage, "hourly_prices must be valid")
}
if m.GigabytePrices.Len() == 0 {
return errors.Wrap(ErrorInvalidMessage, "gigabyte_prices length cannot be zero")
}
if m.GigabytePrices.IsAnyNil() {
return errors.Wrap(ErrorInvalidMessage, "gigabyte_prices cannot contain nil")
}
if !m.GigabytePrices.IsValid() {
return errors.Wrap(ErrorInvalidMessage, "gigabyte_prices must be valid")
}
if m.HourlyPrices == nil {
return errors.Wrap(ErrorInvalidMessage, "hourly_prices cannot be nil")
}
if m.HourlyPrices.Len() == 0 {
return errors.Wrap(ErrorInvalidMessage, "hourly_prices length cannot be zero")
}
if m.HourlyPrices.IsAnyNil() {
return errors.Wrap(ErrorInvalidMessage, "hourly_prices cannot contain nil")
}
if !m.HourlyPrices.IsValid() {
return errors.Wrap(ErrorInvalidMessage, "hourly_prices must be valid")
}
if m.RemoteURL == "" {
return errors.Wrap(ErrorInvalidMessage, "remote_url cannot be empty")
Expand Down Expand Up @@ -247,10 +249,11 @@ func (m *MsgSubscribeRequest) ValidateBasic() error {
return errors.Wrap(ErrorInvalidMessage, "hours cannot be negative")
}
}
if m.Denom != "" {
if err := sdk.ValidateDenom(m.Denom); err != nil {
return errors.Wrap(ErrorInvalidMessage, err.Error())
}
if m.Denom == "" {
return errors.Wrap(ErrorInvalidMessage, "denom cannot be empty")
}
if err := sdk.ValidateDenom(m.Denom); err != nil {
return errors.Wrap(ErrorInvalidMessage, err.Error())
}

return nil
Expand Down
54 changes: 24 additions & 30 deletions x/node/types/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,29 @@ func (m *Node) Validate() error {
if _, err := hubtypes.NodeAddressFromBech32(m.Address); err != nil {
return errors.Wrapf(err, "invalid address %s", m.Address)
}
if m.GigabytePrices != nil {
if m.GigabytePrices.Len() == 0 {
return fmt.Errorf("gigabyte_prices cannot be empty")
}
if m.GigabytePrices.IsAnyNil() {
return fmt.Errorf("gigabyte_prices cannot contain nil")
}
if !m.GigabytePrices.IsValid() {
return fmt.Errorf("gigabyte_prices must be valid")
}
if m.GigabytePrices == nil {
return fmt.Errorf("gigabyte_prices cannot be nil")
}
if m.HourlyPrices != nil {
if m.HourlyPrices.Len() == 0 {
return fmt.Errorf("hourly_prices cannot be empty")
}
if m.HourlyPrices.IsAnyNil() {
return fmt.Errorf("hourly_prices cannot contain nil")
}
if !m.HourlyPrices.IsValid() {
return fmt.Errorf("hourly_prices must be valid")
}
if m.GigabytePrices.Len() == 0 {
return fmt.Errorf("gigabyte_prices cannot be empty")
}
if m.GigabytePrices.IsAnyNil() {
return fmt.Errorf("gigabyte_prices cannot contain nil")
}
if !m.GigabytePrices.IsValid() {
return fmt.Errorf("gigabyte_prices must be valid")
}
if m.HourlyPrices == nil {
return fmt.Errorf("hourly_prices cannot be nil")
}
if m.HourlyPrices.Len() == 0 {
return fmt.Errorf("hourly_prices cannot be empty")
}
if m.HourlyPrices.IsAnyNil() {
return fmt.Errorf("hourly_prices cannot contain nil")
}
if !m.HourlyPrices.IsValid() {
return fmt.Errorf("hourly_prices must be valid")
}
if m.RemoteURL == "" {
return fmt.Errorf("remote_url cannot be empty")
Expand Down Expand Up @@ -91,31 +93,23 @@ func (m *Node) Validate() error {
}

func (m *Node) GigabytePrice(denom string) (sdk.Coin, bool) {
if m.GigabytePrices == nil && denom == "" {
return sdk.Coin{Amount: sdk.NewInt(0)}, true
}

for _, v := range m.GigabytePrices {
if v.Denom == denom {
return v, true
}
}

return sdk.Coin{Amount: sdk.NewInt(0)}, false
return sdk.Coin{}, false
}

func (m *Node) HourlyPrice(denom string) (sdk.Coin, bool) {
if m.HourlyPrices == nil && denom == "" {
return sdk.Coin{Amount: sdk.NewInt(0)}, true
}

for _, v := range m.HourlyPrices {
if v.Denom == denom {
return v, true
}
}

return sdk.Coin{Amount: sdk.NewInt(0)}, false
return sdk.Coin{}, false
}

type (
Expand Down
12 changes: 12 additions & 0 deletions x/node/types/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ func validateMaxSubscriptionGigabytes(v interface{}) error {
if value < 0 {
return fmt.Errorf("max_subscription_gigabytes cannot be negative")
}
if value == 0 {
return fmt.Errorf("max_subscription_gigabytes cannot be zero")
}

return nil
}
Expand All @@ -311,6 +314,9 @@ func validateMinSubscriptionGigabytes(v interface{}) error {
if value < 0 {
return fmt.Errorf("min_subscription_gigabytes cannot be negative")
}
if value == 0 {
return fmt.Errorf("min_subscription_gigabytes cannot be zero")
}

return nil
}
Expand All @@ -324,6 +330,9 @@ func validateMaxSubscriptionHours(v interface{}) error {
if value < 0 {
return fmt.Errorf("max_subscription_hours cannot be negative")
}
if value == 0 {
return fmt.Errorf("max_subscription_hours cannot be zero")
}

return nil
}
Expand All @@ -337,6 +346,9 @@ func validateMinSubscriptionHours(v interface{}) error {
if value < 0 {
return fmt.Errorf("min_subscription_hours cannot be negative")
}
if value == 0 {
return fmt.Errorf("min_subscription_hours cannot be zero")
}

return nil
}
Expand Down

0 comments on commit f302367

Please sign in to comment.