Skip to content

Commit

Permalink
Add type hints to rotation_history
Browse files Browse the repository at this point in the history
  • Loading branch information
DimaStebaev committed May 2, 2024
1 parent 49583c0 commit 4469203
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 75 deletions.
19 changes: 3 additions & 16 deletions skale/contracts/manager/node_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from __future__ import annotations
import logging
import functools
from typing import TYPE_CHECKING, List, TypedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING, List

from eth_typing import ChecksumAddress

Expand All @@ -32,6 +31,7 @@

from skale.contracts.skale_manager_contract import SkaleManagerContract
from skale.types.node import NodeId
from skale.types.rotation import Rotation, RotationSwap
from skale.types.schain import SchainHash, SchainName

if TYPE_CHECKING:
Expand All @@ -44,19 +44,6 @@
NO_PREVIOUS_NODE_EXCEPTION_TEXT = 'No previous node'


@dataclass
class Rotation:
leaving_node_id: int
new_node_id: int
freeze_until: int
rotation_counter: int


class RotationSwap(TypedDict):
schain_id: SchainHash
finished_rotation: int


class NodeRotation(SkaleManagerContract):
"""Wrapper for NodeRotation.sol functions"""

Expand Down Expand Up @@ -107,7 +94,7 @@ def is_rotation_active(self, schain_name: SchainName) -> bool:
return self.is_rotation_in_progress(schain_name) and not finish_ts_reached

def is_finish_ts_reached(self, schain_name: SchainName) -> bool:
rotation = self.skale.node_rotation.get_rotation_obj(schain_name)
rotation = self.skale.node_rotation.get_rotation(schain_name)
schain_finish_ts = self.get_schain_finish_ts(rotation.leaving_node_id, schain_name)

if not schain_finish_ts:
Expand Down
77 changes: 34 additions & 43 deletions skale/schain_config/rotation_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,57 +19,39 @@

from __future__ import annotations
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Optional, TypedDict
from typing import TYPE_CHECKING, Dict, List, TypedDict

from skale import Skale
from skale.contracts.manager.node_rotation import Rotation
from skale.types.rotation import BlsPublicKey, NodesGroup, RotationNodeData

if TYPE_CHECKING:
from skale.contracts.manager.key_storage import G2Point
from skale.skale_manager import SkaleManager
from skale.types.node import NodeId
from skale.types.schain import SchainName

logger = logging.getLogger(__name__)

RotationNodeData = namedtuple('RotationNodeData', ['index', 'node_id', 'public_key'])


class NodesSwap(TypedDict):
leaving_node_id: int
new_node_id: int


class BlsPublicKey(TypedDict):
blsPublicKey0: str
blsPublicKey1: str
blsPublicKey2: str
blsPublicKey3: str


class NodesGroup(TypedDict):
rotation: NodesSwap | None
nodes: dict[int, RotationNodeData]
finish_ts: int | None
bls_public_key: BlsPublicKey | None


def get_previous_schain_groups(
skale: Skale,
schain_name: str,
leaving_node_id=None,
) -> dict:
skale: SkaleManager,
schain_name: SchainName,
leaving_node_id: NodeId | None = None,
) -> Dict[int, NodesGroup]:
"""
Returns all previous node groups with public keys and finish timestamps.
In case of no rotations returns the current state.
"""
logger.info(f'Collecting rotation history for {schain_name}...')
node_groups: dict[int, NodesGroup] = {}

group_id = skale.schains.name_to_group_id(schain_name)
group_id = skale.schains.name_to_id(schain_name)

previous_public_keys = skale.key_storage.get_all_previous_public_keys(group_id)
current_public_key = skale.key_storage.get_common_public_key(group_id)

rotation = skale.node_rotation.get_rotation_obj(schain_name)
rotation = skale.node_rotation.get_rotation(schain_name)

logger.info(f'Rotation data for {schain_name}: {rotation}')

Expand All @@ -93,9 +75,9 @@ def _add_current_schain_state(
skale: Skale,
node_groups: dict[int, NodesGroup],
rotation: Rotation,
schain_name: str,
schain_name: SchainName,
current_public_key: G2Point
):
) -> None:
"""
Internal function, composes the initial info about the current sChain state and adds it to the
node_groups dictionary
Expand All @@ -115,18 +97,23 @@ def _add_current_schain_state(


def _add_previous_schain_rotations_state(
skale: Skale,
node_groups: dict,
skale: SkaleManager,
node_groups: dict[int, NodesGroup],
rotation: Rotation,
schain_name: str,
previous_public_keys: list,
leaving_node_id=None
):
schain_name: SchainName,
previous_public_keys: list[G2Point],
leaving_node_id: NodeId | None = None
) -> None:
"""
Internal function, handles rotations from (rotation_counter - 2) to 0 and adds them to the
node_groups dictionary
"""
previous_nodes = {}

class PreviousNodeData(TypedDict):
finish_ts: int
previous_node_id: NodeId

previous_nodes: Dict[NodeId, PreviousNodeData] = {}

for rotation_id in range(rotation.rotation_counter - 1, -1, -1):
nodes = node_groups[rotation_id + 1]['nodes'].copy()
Expand All @@ -136,7 +123,7 @@ def _add_previous_schain_rotations_state(
if previous_node is not None:
finish_ts = skale.node_rotation.get_schain_finish_ts(previous_node, schain_name)
previous_nodes[node_id] = {
'finish_ts': finish_ts,
'finish_ts': finish_ts or 0,
'previous_node_id': previous_node
}

Expand Down Expand Up @@ -182,7 +169,7 @@ def _add_previous_schain_rotations_state(
break


def _pop_previous_bls_public_key(previous_public_keys):
def _pop_previous_bls_public_key(previous_public_keys: List[G2Point]) -> BlsPublicKey | None:
"""
Returns BLS public key for the group and removes it from the list, returns None if node
with provided node_id was kicked out of the chain because of failed DKG.
Expand All @@ -193,7 +180,7 @@ def _pop_previous_bls_public_key(previous_public_keys):
return bls_keys


def _compose_bls_public_key_info(bls_public_key: G2Point) -> Optional[BlsPublicKey]:
def _compose_bls_public_key_info(bls_public_key: G2Point) -> BlsPublicKey | None:
if bls_public_key:
return {
'blsPublicKey0': str(bls_public_key[0][0]),
Expand All @@ -204,10 +191,14 @@ def _compose_bls_public_key_info(bls_public_key: G2Point) -> Optional[BlsPublicK
return None


def get_new_nodes_list(skale: Skale, name: str, node_groups) -> list:
def get_new_nodes_list(
skale: SkaleManager,
name: SchainName,
node_groups: Dict[int, NodesGroup]
) -> list[NodeId]:
"""Returns list of new nodes in for the latest rotation"""
logger.info(f'Getting new nodes list for chain {name}')
rotation = skale.node_rotation.get_rotation_obj(name)
rotation = skale.node_rotation.get_rotation(name)
current_group_ids = node_groups[rotation.rotation_counter]['nodes'].keys()
new_nodes = []
for index in node_groups:
Expand Down
96 changes: 91 additions & 5 deletions skale/skale_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,27 @@
from skale.utils.helper import get_contracts_info

if TYPE_CHECKING:
from skale.contracts.manager.node_rotation import NodeRotation
from skale.contracts.manager.schains import SChains
from skale.contracts.manager.schains_internal import SChainsInternal
from skale.contracts.manager.token import Token
from skale.contracts.manager import BountyV2
from skale.contracts.manager import ConstantsHolder
from skale.contracts.manager import ContractManager
from skale.contracts.manager import DelegationController
from skale.contracts.manager import DelegationPeriodManager
from skale.contracts.manager import Distributor
from skale.contracts.manager import DKG
from skale.contracts.manager import KeyStorage
from skale.contracts.manager import Manager
from skale.contracts.manager import NodeRotation
from skale.contracts.manager import Nodes
from skale.contracts.manager import Punisher
from skale.contracts.manager import SChains
from skale.contracts.manager import SChainsInternal
from skale.contracts.manager import SlashingTable
from skale.contracts.manager import SyncManager
from skale.contracts.manager import TimeHelpersWithDebug
from skale.contracts.manager import Token
from skale.contracts.manager import TokenState
from skale.contracts.manager import ValidatorService
from skale.contracts.manager import Wallets


logger = logging.getLogger(__name__)
Expand All @@ -42,7 +59,8 @@ class SkaleManager(SkaleBase):
def project_name(self) -> str:
return 'skale-manager'

def contracts_info(self) -> List[ContractInfo[SkaleManager]]:
@staticmethod
def contracts_info() -> List[ContractInfo[SkaleManager]]:
import skale.contracts.manager as contracts
return [
ContractInfo('contract_manager', 'ContractManager',
Expand Down Expand Up @@ -88,10 +106,54 @@ def contracts_info(self) -> List[ContractInfo[SkaleManager]]:
contracts.TimeHelpersWithDebug, ContractTypes.API, False)
]

@property
def bounty_v2(self) -> BountyV2:
return cast('BountyV2', self._get_contract('bounty_v2'))

@property
def constants_holder(self) -> ConstantsHolder:
return cast('ConstantsHolder', self._get_contract('constants_holder'))

@property
def contract_manager(self) -> ContractManager:
return cast('ContractManager', self._get_contract('contract_manager'))

@property
def delegation_controller(self) -> DelegationController:
return cast('DelegationController', self._get_contract('delegation_controller'))

@property
def delegation_period_manager(self) -> DelegationPeriodManager:
return cast('DelegationPeriodManager', self._get_contract('delegation_period_manager'))

@property
def distributor(self) -> Distributor:
return cast('Distributor', self._get_contract('distributor'))

@property
def dkg(self) -> DKG:
return cast('DKG', self._get_contract('dkg'))

@property
def key_storage(self) -> KeyStorage:
return cast('KeyStorage', self._get_contract('key_storage'))

@property
def manager(self) -> Manager:
return cast('Manager', self._get_contract('manager'))

@property
def node_rotation(self) -> NodeRotation:
return cast('NodeRotation', self._get_contract('node_rotation'))

@property
def nodes(self) -> Nodes:
return cast('Nodes', self._get_contract('nodes'))

@property
def punisher(self) -> Punisher:
return cast('Punisher', self._get_contract('punisher'))

@property
def schains(self) -> SChains:
return cast('SChains', self._get_contract('schains'))
Expand All @@ -100,10 +162,34 @@ def schains(self) -> SChains:
def schains_internal(self) -> SChainsInternal:
return cast('SChainsInternal', self._get_contract('schains_internal'))

@property
def slashing_table(self) -> SlashingTable:
return cast('SlashingTable', self._get_contract('slashing_table'))

@property
def sync_manager(self) -> SyncManager:
return cast('SyncManager', self._get_contract('sync_manager'))

@property
def time_helpers_with_debug(self) -> TimeHelpersWithDebug:
return cast('TimeHelpersWithDebug', self._get_contract('time_helpers_with_debug'))

@property
def token(self) -> Token:
return cast('Token', self._get_contract('token'))

@property
def token_state(self) -> TokenState:
return cast('TokenState', self._get_contract('token_state'))

@property
def validator_service(self) -> ValidatorService:
return cast('ValidatorService', self._get_contract('validator_service'))

@property
def wallets(self) -> Wallets:
return cast('Wallets', self._get_contract('wallets'))

def init_contract_manager(self) -> None:
from skale.contracts.manager.contract_manager import ContractManager
self.add_lib_contract('contract_manager', ContractManager, 'ContractManager')
Expand Down
6 changes: 4 additions & 2 deletions skale/types/dkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from collections import namedtuple
from typing import List, NamedTuple, NewType, Tuple

from eth_typing import HexStr


Fp2Point = namedtuple('Fp2Point', ['a', 'b'])

Expand All @@ -33,5 +35,5 @@ class G2Point(NamedTuple):


class KeyShare(NamedTuple):
publicKey: Tuple[bytes, bytes]
share: bytes
publicKey: Tuple[bytes | HexStr, bytes | HexStr]
share: bytes | HexStr
Loading

0 comments on commit 4469203

Please sign in to comment.