diff --git a/raiden/message_handler.py b/raiden/message_handler.py index 7877f498f2..e08ec5a377 100644 --- a/raiden/message_handler.py +++ b/raiden/message_handler.py @@ -1,3 +1,5 @@ +from math import inf + import structlog from eth_utils import to_hex from gevent import joinall @@ -42,6 +44,7 @@ ) from raiden.transfer.utils import decrypt_secret from raiden.transfer.views import TransferRole +from raiden.utils.formatting import to_checksum_address from raiden.utils.transfers import random_secret from raiden.utils.typing import ( TYPE_CHECKING, @@ -51,6 +54,7 @@ Set, TargetAddress, Tuple, + Union, ) if TYPE_CHECKING: @@ -104,18 +108,21 @@ def on_messages(self, raiden: "RaidenService", messages: List[Message]) -> None: # (an asynchronous network is assumed) This reduces latency when a # balance proof is considered invalid because of a race with the # blockchain view of each node. - def by_canonical_identifier(state_change: StateChange) -> Tuple[int, int]: + def by_canonical_identifier( + state_change: StateChange, + ) -> Union[Tuple[int, int], Tuple[float, float]]: if isinstance(state_change, BalanceProofStateChange): balance_proof = state_change.balance_proof return ( balance_proof.canonical_identifier.channel_identifier, balance_proof.nonce, ) - + elif isinstance(state_change, ReceiveSecretReveal): + # ReceiveSecretReveal depends on other state changes happening first. + return inf, inf return 0, 0 all_state_changes.sort(key=by_canonical_identifier) - raiden.handle_and_track_state_changes(all_state_changes) @staticmethod @@ -343,10 +350,20 @@ def handle_message_lockedtransfer( if encrypted_secret is not None: try: secret = decrypt_secret(encrypted_secret, raiden.rpc_client.privkey) - log.info(f"Using encrypted secret received from {sender.hex()}") - return [ReceiveSecretReveal(secret=secret, sender=message.sender)] + log.info("Using encrypted secret", sender=to_checksum_address(sender)) + return [ + ActionInitTarget( + from_hop=from_hop, + transfer=from_transfer, + balance_proof=balance_proof, + sender=sender, + received_valid_secret=True, + ), + ReceiveSecretReveal(secret=secret, sender=message.sender), + ] except InvalidSecret: - log.error(f"Ignoring invalid encrypted secret received from {sender.hex()}") + sender_addr = to_checksum_address(sender) + log.error("Ignoring invalid encrypted secret", sender=sender_addr) return [ ActionInitTarget( from_hop=from_hop, diff --git a/raiden/tests/integration/api/rest/test_rest.py b/raiden/tests/integration/api/rest/test_rest.py index 6f2832c8ea..82e96f1ada 100644 --- a/raiden/tests/integration/api/rest/test_rest.py +++ b/raiden/tests/integration/api/rest/test_rest.py @@ -1,6 +1,7 @@ import datetime import json from http import HTTPStatus +from unittest.mock import Mock, patch import gevent import grequests @@ -11,6 +12,7 @@ from raiden.api.python import RaidenAPI from raiden.api.rest import APIServer from raiden.constants import BLOCK_ID_LATEST, Environment +from raiden.exceptions import InvalidSecret from raiden.messages.transfers import LockedTransfer, Unlock from raiden.raiden_service import RaidenService from raiden.settings import ( @@ -937,7 +939,10 @@ def test_channel_events_raiden( @pytest.mark.parametrize("number_of_nodes", [3]) @pytest.mark.parametrize("channels_per_node", [CHAIN]) @pytest.mark.parametrize("enable_rest_api", [True]) -def test_pending_transfers_endpoint(raiden_network: List[RaidenService], token_addresses): +@patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret) +def test_pending_transfers_endpoint( + decrypt_patch: Mock, raiden_network: List[RaidenService], token_addresses +): initiator, mediator, target = raiden_network token_address = token_addresses[0] token_network_address = views.get_token_network_address_by_token_address( diff --git a/raiden/tests/integration/api/test_pythonapi.py b/raiden/tests/integration/api/test_pythonapi.py index b86d8702fb..485df247a5 100644 --- a/raiden/tests/integration/api/test_pythonapi.py +++ b/raiden/tests/integration/api/test_pythonapi.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import gevent import pytest from eth_utils import to_canonical_address @@ -12,6 +14,7 @@ InsufficientEth, InsufficientGasReserve, InvalidBinaryAddress, + InvalidSecret, InvalidSettleTimeout, RaidenRecoverableError, SamePeerAddress, @@ -353,15 +356,18 @@ def test_payment_timing_out_if_partner_does_not_respond( # pylint: disable=unus assert isinstance(app1.raiden_event_handler, HoldRaidenEventHandler), msg app1.raiden_event_handler.hold(SendSecretRequest, {}) - greenlet = gevent.spawn( - RaidenAPI(app0).transfer_and_wait, - app0.default_registry.address, - token_address, - 1, - target=app1.address, - ) - waiting.wait_for_block(app0, app1.get_block_number() + 2 * reveal_timeout + 1, retry_timeout) - greenlet.join(timeout=5) + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + greenlet = gevent.spawn( + RaidenAPI(app0).transfer_and_wait, + app0.default_registry.address, + token_address, + 1, + target=app1.address, + ) + waiting.wait_for_block( + app0, app1.get_block_number() + 2 * reveal_timeout + 1, retry_timeout + ) + greenlet.join(timeout=5) assert not greenlet.value diff --git a/raiden/tests/integration/long_running/test_settlement.py b/raiden/tests/integration/long_running/test_settlement.py index 5bb71e1d65..a053f3131c 100644 --- a/raiden/tests/integration/long_running/test_settlement.py +++ b/raiden/tests/integration/long_running/test_settlement.py @@ -1,4 +1,5 @@ import random +from unittest.mock import patch import gevent import pytest @@ -8,7 +9,7 @@ from raiden import waiting from raiden.api.python import RaidenAPI from raiden.constants import BLOCK_ID_LATEST, EMPTY_SIGNATURE, UINT64_MAX -from raiden.exceptions import RaidenUnrecoverableError +from raiden.exceptions import InvalidSecret, RaidenUnrecoverableError from raiden.messages.transfers import LockedTransfer, LockExpired, RevealSecret, Unlock from raiden.messages.withdraw import WithdrawExpired from raiden.raiden_service import RaidenService @@ -241,15 +242,16 @@ def test_lock_expiry( LockExpired, {"secrethash": transfer_1_secrethash} ) - alice_app.mediated_transfer_async( - token_network_address=token_network_address, - amount=alice_to_bob_amount, - target=target, - identifier=identifier, - secret=transfer_1_secret, - route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], - ) - transfer1_received.wait() + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + alice_app.mediated_transfer_async( + token_network_address=token_network_address, + amount=alice_to_bob_amount, + target=target, + identifier=identifier, + secret=transfer_1_secret, + route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], + ) + transfer1_received.wait() alice_bob_channel_state = get_channelstate(alice_app, bob_app, token_network_address) lock = channel.get_lock(alice_bob_channel_state.our_state, transfer_1_secrethash) @@ -292,15 +294,16 @@ def test_lock_expiry( hold_event_handler.hold_secretrequest_for(secrethash=transfer_2_secrethash) - alice_app.mediated_transfer_async( - token_network_address=token_network_address, - amount=alice_to_bob_amount, - target=target, - identifier=identifier, - secret=transfer_2_secret, - route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], - ) - transfer2_received.wait() + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + alice_app.mediated_transfer_async( + token_network_address=token_network_address, + amount=alice_to_bob_amount, + target=target, + identifier=identifier, + secret=transfer_2_secret, + route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], + ) + transfer2_received.wait() # Make sure the other transfer still exists alice_chain_state = views.state_from_raiden(alice_app) @@ -356,16 +359,17 @@ def test_batch_unlock( secret_request_event = hold_event_handler.hold_secretrequest_for(secrethash=secrethash) - alice_app.mediated_transfer_async( - token_network_address=token_network_address, - amount=PaymentAmount(alice_to_bob_amount), - target=TargetAddress(bob_address), - identifier=PaymentID(identifier), - secret=secret, - route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], - ) + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + alice_app.mediated_transfer_async( + token_network_address=token_network_address, + amount=PaymentAmount(alice_to_bob_amount), + target=TargetAddress(bob_address), + identifier=PaymentID(identifier), + secret=secret, + route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], + ) - secret_request_event.get() # wait for the messages to be exchanged + secret_request_event.get() # wait for the messages to be exchanged alice_bob_channel_state = get_channelstate(alice_app, bob_app, token_network_address) lock = channel.get_lock(alice_bob_channel_state.our_state, secrethash) @@ -699,16 +703,17 @@ def test_settled_lock( secret_available = hold_event_handler.hold_secretrequest_for(secrethash=secrethash) - app0.mediated_transfer_async( - token_network_address=token_network_address, - amount=amount, - target=target, - identifier=identifier, - secret=secret, - route_states=[create_route_state_for_route([app0, app1], token_address)], - ) + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + app0.mediated_transfer_async( + token_network_address=token_network_address, + amount=amount, + target=target, + identifier=identifier, + secret=secret, + route_states=[create_route_state_for_route([app0, app1], token_address)], + ) - secret_available.wait() # wait for the messages to be exchanged + secret_available.wait() # wait for the messages to be exchanged # Save the pending locks from the pending transfer, used to test the unlock channelstate_0_1 = get_channelstate(app0, app1, token_network_address) @@ -716,14 +721,15 @@ def test_settled_lock( hold_event_handler.release_secretrequest_for(app1, secrethash) - transfer( - initiator_app=app0, - target_app=app1, - token_address=token_address, - amount=amount, - identifier=PaymentID(2), - routes=[[app0, app1]], - ) + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + transfer( + initiator_app=app0, + target_app=app1, + token_address=token_address, + amount=amount, + identifier=PaymentID(2), + routes=[[app0, app1]], + ) # The channel state has to be recovered before the settlement, otherwise # the object is cleared from the node's state. @@ -1077,26 +1083,27 @@ def test_batch_unlock_after_restart( secrethash=bob_transfer_secrethash ) - alice_app.mediated_transfer_async( - token_network_address=token_network_address, - amount=alice_to_bob_amount, - target=TargetAddress(bob_app.address), - identifier=identifier, - secret=alice_transfer_secret, - route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], - ) + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + alice_app.mediated_transfer_async( + token_network_address=token_network_address, + amount=alice_to_bob_amount, + target=TargetAddress(bob_app.address), + identifier=identifier, + secret=alice_transfer_secret, + route_states=[create_route_state_for_route([alice_app, bob_app], token_address)], + ) - bob_app.mediated_transfer_async( - token_network_address=token_network_address, - amount=alice_to_bob_amount, - target=TargetAddress(alice_app.address), - identifier=PaymentID(identifier + 1), - secret=bob_transfer_secret, - route_states=[create_route_state_for_route([bob_app, alice_app], token_address)], - ) + bob_app.mediated_transfer_async( + token_network_address=token_network_address, + amount=alice_to_bob_amount, + target=TargetAddress(alice_app.address), + identifier=PaymentID(identifier + 1), + secret=bob_transfer_secret, + route_states=[create_route_state_for_route([bob_app, alice_app], token_address)], + ) - alice_transfer_hold.wait(timeout=timeout) - bob_transfer_hold.wait(timeout=timeout) + alice_transfer_hold.wait(timeout=timeout) + bob_transfer_hold.wait(timeout=timeout) alice_bob_channel_state = get_channelstate(alice_app, bob_app, token_network_address) alice_lock = channel.get_lock(alice_bob_channel_state.our_state, alice_transfer_secrethash) diff --git a/raiden/tests/integration/test_regression_parity.py b/raiden/tests/integration/test_regression_parity.py index cb3dbe9a4d..d0f5433c11 100644 --- a/raiden/tests/integration/test_regression_parity.py +++ b/raiden/tests/integration/test_regression_parity.py @@ -1,10 +1,12 @@ import math from typing import cast +from unittest.mock import Mock, patch import pytest from raiden import waiting from raiden.constants import BLOCK_ID_LATEST +from raiden.exceptions import InvalidSecret from raiden.network.rpc.client import JSONRPCClient from raiden.raiden_service import RaidenService from raiden.settings import DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS @@ -41,7 +43,9 @@ @raise_on_failure @pytest.mark.parametrize("number_of_nodes", [2]) @pytest.mark.parametrize("blockchain_extra_config", [STATE_PRUNING]) -def test_locksroot_loading_during_channel_settle_handling( +@patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret) +def test_locksroot_loading_during_channel_settle_handling( # pylint: disable=unused-argument + decrypt_patch: Mock, raiden_chain: List[RaidenService], restart_node: RestartNode, deploy_client: JSONRPCClient, diff --git a/raiden/tests/integration/test_send_queued_messages.py b/raiden/tests/integration/test_send_queued_messages.py index b730493f94..09f4de6d5b 100644 --- a/raiden/tests/integration/test_send_queued_messages.py +++ b/raiden/tests/integration/test_send_queued_messages.py @@ -1,8 +1,11 @@ +from unittest.mock import Mock, patch + import gevent import pytest from raiden import waiting from raiden.constants import RoutingMode +from raiden.exceptions import InvalidSecret from raiden.message_handler import MessageHandler from raiden.network.transport import MatrixTransport from raiden.raiden_event_handler import RaidenEventHandler @@ -164,7 +167,9 @@ def test_send_queued_messages_after_restart( # pylint: disable=unused-argument @pytest.mark.parametrize("number_of_nodes", [2]) @pytest.mark.parametrize("channels_per_node", [1]) @pytest.mark.parametrize("number_of_tokens", [1]) -def test_payment_statuses_are_restored( +@patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret) +def test_payment_statuses_are_restored( # pylint: disable=unused-argument + decrypt_patch: Mock, raiden_network: List[RaidenService], restart_node: RestartNode, token_addresses: List[TokenAddress], diff --git a/raiden/tests/integration/transfer/test_mediatedtransfer.py b/raiden/tests/integration/transfer/test_mediatedtransfer.py index 2113b2f8ab..0f81056c25 100644 --- a/raiden/tests/integration/transfer/test_mediatedtransfer.py +++ b/raiden/tests/integration/transfer/test_mediatedtransfer.py @@ -1,9 +1,10 @@ from typing import List, cast -from unittest.mock import patch +from unittest.mock import Mock, patch +import gevent import pytest -from raiden.exceptions import RaidenUnrecoverableError +from raiden.exceptions import InvalidSecret, RaidenUnrecoverableError from raiden.message_handler import MessageHandler from raiden.messages.transfers import LockedTransfer, RevealSecret, SecretRequest from raiden.network.pathfinding import PFSConfig, PFSInfo, PFSProxy @@ -16,8 +17,9 @@ from raiden.tests.utils.factories import make_secret from raiden.tests.utils.mediation_fees import get_amount_for_sending_before_and_after_fees from raiden.tests.utils.network import CHAIN -from raiden.tests.utils.protocol import WaitForMessage +from raiden.tests.utils.protocol import HoldRaidenEventHandler, WaitForMessage from raiden.tests.utils.transfer import ( + TransferState, assert_succeeding_transfer_invariants, assert_synced_channel_state, block_timeout_for_transfer_by_secrethash, @@ -27,6 +29,7 @@ wait_assert, ) from raiden.transfer import views +from raiden.transfer.mediated_transfer.events import SendSecretRequest from raiden.transfer.mediated_transfer.initiator import calculate_fee_margin from raiden.transfer.mediated_transfer.mediation_fee import FeeScheduleState from raiden.transfer.mediated_transfer.state_change import ActionInitMediator, ActionInitTarget @@ -49,6 +52,48 @@ from raiden.waiting import wait_for_block +@raise_on_failure +@pytest.mark.parametrize("channels_per_node", [CHAIN]) +@pytest.mark.parametrize("number_of_nodes", [2]) +def test_transfer_with_secret( + raiden_network: List[RaidenService], number_of_nodes, deposit, token_addresses, network_wait +): + app0, app1 = raiden_network + token_address = token_addresses[0] + chain_state = views.state_from_raiden(app0) + token_network_registry_address = app0.default_registry.address + token_network_address = views.get_token_network_address_by_token_address( + chain_state, token_network_registry_address, token_address + ) + + amount = PaymentAmount(10) + secret_hash = transfer( + initiator_app=app0, + target_app=app1, + token_address=token_address, + amount=amount, + transfer_state=TransferState.LOCKED, + identifier=PaymentID(1), + timeout=network_wait * number_of_nodes, + routes=[[app0, app1]], + ) + + assert isinstance(app1.raiden_event_handler, HoldRaidenEventHandler) + app1.raiden_event_handler.hold(SendSecretRequest, {"secrethash": secret_hash}) + + with gevent.Timeout(20): + wait_assert( + assert_succeeding_transfer_invariants, + token_network_address, + app0, + deposit - amount, + [], + app1, + deposit + amount, + [], + ) + + @raise_on_failure @pytest.mark.parametrize("channels_per_node", [CHAIN]) @pytest.mark.parametrize("number_of_nodes", [3]) @@ -440,8 +485,14 @@ def test_mediated_transfer_calls_pfs( @raise_on_failure @pytest.mark.parametrize("channels_per_node", [CHAIN]) @pytest.mark.parametrize("number_of_nodes", [3]) +@patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret) def test_mediated_transfer_with_node_consuming_more_than_allocated_fee( - raiden_network: List[RaidenService], number_of_nodes, deposit, token_addresses, network_wait + decrypt_patch: Mock, + raiden_network: List[RaidenService], + number_of_nodes, + deposit, + token_addresses, + network_wait, ): """ Tests a mediator node consuming more fees than allocated. diff --git a/raiden/tests/integration/transfer/test_mediatedtransfer_events.py b/raiden/tests/integration/transfer/test_mediatedtransfer_events.py index a992c6f410..6a5e6399c7 100644 --- a/raiden/tests/integration/transfer/test_mediatedtransfer_events.py +++ b/raiden/tests/integration/transfer/test_mediatedtransfer_events.py @@ -1,5 +1,8 @@ +from unittest.mock import patch + import pytest +from raiden.exceptions import InvalidSecret from raiden.tests.utils.detect_failure import raise_on_failure from raiden.tests.utils.events import search_for_item from raiden.tests.utils.network import CHAIN @@ -22,15 +25,16 @@ def test_mediated_transfer_events(raiden_network, number_of_nodes, token_address token_address = token_addresses[0] amount = 10 - transfer( - initiator_app=app0, - target_app=app2, - token_address=token_address, - amount=PaymentAmount(amount), - identifier=PaymentID(1), - timeout=network_wait * number_of_nodes, - routes=[[app0, app1, app2]], - ) + with patch("raiden.message_handler.decrypt_secret", side_effect=InvalidSecret): + transfer( + initiator_app=app0, + target_app=app2, + token_address=token_address, + amount=PaymentAmount(amount), + identifier=PaymentID(1), + timeout=network_wait * number_of_nodes, + routes=[[app0, app1, app2]], + ) def test_initiator_events(): assert not has_unlock_failure(app0) diff --git a/raiden/tests/unit/test_channelstate.py b/raiden/tests/unit/test_channelstate.py index 9677b2c2b9..04f34c6f54 100644 --- a/raiden/tests/unit/test_channelstate.py +++ b/raiden/tests/unit/test_channelstate.py @@ -1224,7 +1224,7 @@ def test_channelstate_unlock_without_locks(): block_number=77, block_hash=make_block_hash(), ) - iteration = channel.handle_channel_closed(channel_state, state_change) + iteration = channel._handle_channel_closed(state_change, channel_state) assert not iteration.events @@ -1313,7 +1313,7 @@ def test_channelstate_unlock_unlocked_onchain(): block_number=closed_block_number, block_hash=closed_block_hash, ) - iteration = channel.handle_channel_closed(channel_state, close_state_change) + iteration = channel._handle_channel_closed(close_state_change, channel_state) assert search_for_item(iteration.events, ContractSendChannelBatchUnlock, {}) is None settle_block_number = lock_expiration + channel_state.reveal_timeout + 1 @@ -1329,7 +1329,7 @@ def test_channelstate_unlock_unlocked_onchain(): our_onchain_locksroot=LOCKSROOT_OF_NO_LOCKS, ) - iteration = channel.handle_channel_settled(channel_state, settle_state_change) + iteration = channel._handle_channel_settled(settle_state_change, channel_state) assert search_for_item(iteration.events, ContractSendChannelBatchUnlock, {}) is not None @@ -1424,7 +1424,7 @@ def test_update_must_be_called_if_close_lost_race(): block_number=77, block_hash=make_block_hash(), ) - iteration = channel.handle_channel_closed(channel_state, state_change) + iteration = channel._handle_channel_closed(state_change, channel_state) assert search_for_item(iteration.events, ContractSendChannelUpdateTransfer, {}) is not None @@ -1461,7 +1461,7 @@ def test_update_transfer(): block_number=closed_block_number, block_hash=closed_block_hash, ) - iteration2 = channel.handle_channel_closed(channel_state, channel_close_state_change) + iteration2 = channel._handle_channel_closed(channel_close_state_change, channel_state) # update_transaction in channel state should not be set because there was no transfer channel_state = iteration2.new_state @@ -1476,8 +1476,8 @@ def test_update_transfer(): ) update_block_number = 20 - iteration3 = channel.handle_channel_updated_transfer( - channel_state, update_transfer_state_change, update_block_number + iteration3 = channel._handle_channel_updated_transfer( + update_transfer_state_change, channel_state, update_block_number ) # now update_transaction in channel state should be set @@ -1591,9 +1591,9 @@ def test_action_withdraw(): canonical_identifier=channel_state.canonical_identifier, total_withdraw=100 ) - iteration = channel.handle_action_withdraw( + iteration = channel._handle_action_withdraw( channel_state=channel_state, - action_withdraw=action_withdraw, + action=action_withdraw, pseudo_random_generator=pseudo_random_generator, block_number=2, ) @@ -1608,9 +1608,9 @@ def test_action_withdraw(): canonical_identifier=channel_state.canonical_identifier, total_withdraw=our_balance ) - iteration = channel.handle_action_withdraw( + iteration = channel._handle_action_withdraw( channel_state=channel_state, - action_withdraw=action_withdraw, + action=action_withdraw, pseudo_random_generator=pseudo_random_generator, block_number=3, ) @@ -1626,9 +1626,9 @@ def test_action_withdraw(): canonical_identifier=channel_state.canonical_identifier, total_withdraw=our_balance ) - iteration = channel.handle_action_withdraw( + iteration = channel._handle_action_withdraw( channel_state=iteration.new_state, - action_withdraw=action_withdraw, + action=action_withdraw, pseudo_random_generator=pseudo_random_generator, block_number=4, ) @@ -1664,8 +1664,8 @@ def test_receive_withdraw_request(): expiration=expiration, ) - iteration = channel.handle_receive_withdraw_request( - channel_state=channel_state, withdraw_request=withdraw_request + iteration = channel._handle_receive_withdraw_request( + channel_state=channel_state, action=withdraw_request ) assert ( @@ -1698,8 +1698,8 @@ def test_receive_withdraw_request(): expiration=expiration, ) - iteration = channel.handle_receive_withdraw_request( - channel_state=channel_state, withdraw_request=withdraw_request + iteration = channel._handle_receive_withdraw_request( + channel_state=channel_state, action=withdraw_request ) # pylint: disable=no-member @@ -1724,8 +1724,8 @@ def test_receive_withdraw_request(): expiration=10, ) - iteration = channel.handle_receive_withdraw_request( - channel_state=iteration.new_state, withdraw_request=withdraw_request + iteration = channel._handle_receive_withdraw_request( + channel_state=iteration.new_state, action=withdraw_request ) assert ( @@ -1749,8 +1749,8 @@ def test_receive_withdraw_request(): expiration=10, ) - iteration = channel.handle_receive_withdraw_request( - channel_state=iteration.new_state, withdraw_request=withdraw_request + iteration = channel._handle_receive_withdraw_request( + channel_state=iteration.new_state, action=withdraw_request ) assert ( @@ -1803,9 +1803,9 @@ def test_receive_withdraw_confirmation(): expiration=expiration_block, ) - iteration = channel.handle_receive_withdraw_confirmation( + iteration = channel._handle_receive_withdraw_confirmation( channel_state=channel_state, - withdraw=receive_withdraw, + action=receive_withdraw, block_number=10, block_hash=block_hash, ) @@ -1830,9 +1830,9 @@ def test_receive_withdraw_confirmation(): expiration=expiration_block, ) - iteration = channel.handle_receive_withdraw_confirmation( + iteration = channel._handle_receive_withdraw_confirmation( channel_state=iteration.new_state, - withdraw=receive_withdraw, + action=receive_withdraw, block_number=10, block_hash=block_hash, ) @@ -1855,9 +1855,9 @@ def test_receive_withdraw_confirmation(): expiration=expiration_block, ) - iteration = channel.handle_receive_withdraw_confirmation( + iteration = channel._handle_receive_withdraw_confirmation( channel_state=iteration.new_state, - withdraw=receive_withdraw, + action=receive_withdraw, block_number=10, block_hash=block_hash, ) @@ -1891,9 +1891,9 @@ def test_node_sends_withdraw_expiry(): block_hash = make_transaction_hash() block = Block(block_number=expiration_threshold - 1, gas_limit=1, block_hash=block_hash) - iteration = channel.handle_block( + iteration = channel._handle_block( channel_state=channel_state, - state_change=block, + action=block, block_number=expiration_threshold - 1, pseudo_random_generator=pseudo_random_generator, ) @@ -1903,9 +1903,9 @@ def test_node_sends_withdraw_expiry(): block_hash = make_transaction_hash() block = Block(block_number=expiration_threshold, gas_limit=1, block_hash=block_hash) - iteration = channel.handle_block( + iteration = channel._handle_block( channel_state=channel_state, - state_change=block, + action=block, block_number=expiration_threshold, pseudo_random_generator=pseudo_random_generator, ) @@ -2210,8 +2210,8 @@ def test_receive_contract_withdraw(): fee_config=MediationFeeConfig(), ) - iteration = channel.handle_channel_withdraw( - channel_state=channel_state, state_change=contract_receive_withdraw + iteration = channel._handle_channel_withdraw( + channel_state=channel_state, action=contract_receive_withdraw ) assert iteration.new_state.our_state.offchain_total_withdraw == 0 @@ -2232,8 +2232,8 @@ def test_receive_contract_withdraw(): fee_config=MediationFeeConfig(), ) - iteration = channel.handle_channel_withdraw( - channel_state=iteration.new_state, state_change=contract_receive_withdraw + iteration = channel._handle_channel_withdraw( + channel_state=iteration.new_state, action=contract_receive_withdraw ) assert iteration.new_state.partner_state.offchain_total_withdraw == 0 diff --git a/raiden/tests/unit/transfer/test_channel.py b/raiden/tests/unit/transfer/test_channel.py index f450aa338e..f8cc708653 100644 --- a/raiden/tests/unit/transfer/test_channel.py +++ b/raiden/tests/unit/transfer/test_channel.py @@ -21,8 +21,8 @@ compute_locksroot, get_secret, get_status, - handle_block, handle_receive_lockedtransfer, + handle_state_transitions, is_balance_proof_usable_onchain, is_valid_balanceproof_signature, set_settled, @@ -357,20 +357,22 @@ def test_handle_block_closed_channel(): ) pseudo_random_generator = random.Random() block = Block(block_number=90, gas_limit=100000, block_hash=factories.make_block_hash()) - before_settle = handle_block( + before_settle = handle_state_transitions( + block, channel_state=channel_state, - state_change=block, block_number=block.block_number, + block_hash=None, pseudo_random_generator=pseudo_random_generator, ) assert get_status(before_settle.new_state) == ChannelState.STATE_CLOSED assert not before_settle.events block = Block(block_number=102, gas_limit=100000, block_hash=factories.make_block_hash()) - after_settle = handle_block( + after_settle = handle_state_transitions( + block, channel_state=before_settle.new_state, - state_change=block, block_number=block.block_number, + block_hash=None, pseudo_random_generator=pseudo_random_generator, ) assert get_status(after_settle.new_state) == ChannelState.STATE_SETTLING diff --git a/raiden/tests/utils/transfer.py b/raiden/tests/utils/transfer.py index 1005d1e77a..0c537e624d 100644 --- a/raiden/tests/utils/transfer.py +++ b/raiden/tests/utils/transfer.py @@ -95,6 +95,7 @@ class TransferState(Enum): """Represents the target state of a transfer.""" + LOCKED = "locked" UNLOCKED = "unlocked" EXPIRED = "expired" SECRET_NOT_REVEALED = "secret_not_revealed" @@ -244,6 +245,16 @@ def transfer( timeout=timeout, route_states=route_states, ) + elif transfer_state is TransferState.LOCKED: + return _transfer_locked( + initiator_app=initiator_app, + target_app=target_app, + token_address=token_address, + amount=amount, + identifier=identifier, + timeout=timeout, + route_states=route_states, + ) else: raise RuntimeError("Type of transfer not implemented.") @@ -392,6 +403,40 @@ def _transfer_secret_not_requested( return secrethash +def _transfer_locked( + initiator_app: RaidenService, + target_app: RaidenService, + token_address: TokenAddress, + amount: PaymentAmount, + identifier: PaymentID, + timeout: Optional[float] = None, + route_states: List[RouteState] = None, +) -> SecretHash: + if timeout is None: + timeout = 10 + + secret, secrethash = make_secret_with_hash() + + token_network_registry_address = initiator_app.default_registry.address + token_network_address = views.get_token_network_address_by_token_address( + chain_state=views.state_from_raiden(initiator_app), + token_network_registry_address=token_network_registry_address, + token_address=token_address, + ) + assert token_network_address is not None + initiator_app.mediated_transfer_async( + token_network_address=token_network_address, + amount=amount, + target=TargetAddress(target_app.address), + identifier=identifier, + secret=secret, + secrethash=secrethash, + route_states=route_states, + ) + + return secrethash + + def transfer_and_assert_path( path: List[RaidenService], token_address: TokenAddress, diff --git a/raiden/transfer/channel.py b/raiden/transfer/channel.py index d2db434be6..e808301824 100644 --- a/raiden/transfer/channel.py +++ b/raiden/transfer/channel.py @@ -1,6 +1,7 @@ -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,unused-argument import random from enum import Enum +from functools import singledispatch from typing import TYPE_CHECKING from eth_utils import encode_hex, keccak, to_hex @@ -79,7 +80,7 @@ ReceiveWithdrawExpired, ReceiveWithdrawRequest, ) -from raiden.transfer.utils import FuncMap, hash_balance_data +from raiden.transfer.utils import hash_balance_data from raiden.utils.formatting import to_checksum_address from raiden.utils.packing import pack_balance_proof, pack_withdraw from raiden.utils.signer import recover @@ -1853,14 +1854,27 @@ def register_onchain_secret( ) -def handle_action_close( +@singledispatch +def handle_state_transitions( + action: StateChange, channel_state: NettingChannelState, - close: ActionChannelClose, block_number: BlockNumber, block_hash: BlockHash, + pseudo_random_generator: random.Random, +) -> TransitionResult[Optional[NettingChannelState]]: + return TransitionResult(channel_state, []) + + +@handle_state_transitions.register +def _handle_action_close( + action: ActionChannelClose, + channel_state: NettingChannelState, + block_number: BlockNumber, + block_hash: BlockHash, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: msg = "caller must make sure the ids match" - assert channel_state.identifier == close.channel_identifier, msg + assert channel_state.identifier == action.channel_identifier, msg events = events_for_close( channel_state=channel_state, block_number=block_number, block_hash=block_hash @@ -1868,37 +1882,40 @@ def handle_action_close( return TransitionResult(channel_state, events) -def handle_action_withdraw( +@handle_state_transitions.register +def _handle_action_withdraw( + action: ActionChannelWithdraw, channel_state: NettingChannelState, - action_withdraw: ActionChannelWithdraw, pseudo_random_generator: random.Random, block_number: BlockNumber, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: events: List[Event] = list() - is_valid_withdraw = is_valid_action_withdraw(channel_state, action_withdraw) + is_valid_withdraw = is_valid_action_withdraw(channel_state, action) if is_valid_withdraw: events = send_withdraw_request( channel_state=channel_state, - total_withdraw=action_withdraw.total_withdraw, + total_withdraw=action.total_withdraw, block_number=block_number, pseudo_random_generator=pseudo_random_generator, - recipient_metadata=action_withdraw.recipient_metadata, + recipient_metadata=action.recipient_metadata, ) else: error_msg = is_valid_withdraw.as_error_message - assert error_msg, "is_valid_action_withdraw should return error msg if not valid" + assert error_msg, "is_valid_action should return error msg if not valid" events = [ - EventInvalidActionWithdraw( - attempted_withdraw=action_withdraw.total_withdraw, reason=error_msg - ) + EventInvalidActionWithdraw(attempted_withdraw=action.total_withdraw, reason=error_msg) ] return TransitionResult(channel_state, events) -def handle_action_set_reveal_timeout( - channel_state: NettingChannelState, state_change: ActionChannelSetRevealTimeout +@handle_state_transitions.register +def _handle_action_set_reveal_timeout( + action: ActionChannelSetRevealTimeout, + channel_state: NettingChannelState, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: events: List[Event] = list() @@ -1907,48 +1924,49 @@ def handle_action_set_reveal_timeout( # + the min amount of blocks expected for that transaction to be mined according # to fastest gas strategy where a transaction takes roughly 60 seconds to # be mined, which is roughly equal to 7 blocks. - state_change.reveal_timeout >= 7 - and channel_state.settle_timeout >= state_change.reveal_timeout * 2 + action.reveal_timeout >= 7 + and channel_state.settle_timeout >= action.reveal_timeout * 2 ) if is_valid_reveal_timeout: - channel_state.reveal_timeout = state_change.reveal_timeout + channel_state.reveal_timeout = action.reveal_timeout else: error_msg = "Settle timeout should be at least twice as large as reveal timeout" events = [ EventInvalidActionSetRevealTimeout( - reveal_timeout=state_change.reveal_timeout, reason=error_msg + reveal_timeout=action.reveal_timeout, reason=error_msg ) ] return TransitionResult(channel_state, events) -def handle_receive_withdraw_request( - channel_state: NettingChannelState, withdraw_request: ReceiveWithdrawRequest +@handle_state_transitions.register +def _handle_receive_withdraw_request( + action: ReceiveWithdrawRequest, + channel_state: NettingChannelState, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: - is_valid = is_valid_withdraw_request( - channel_state=channel_state, withdraw_request=withdraw_request - ) + is_valid = is_valid_withdraw_request(channel_state=channel_state, withdraw_request=action) if is_valid: withdraw_state = PendingWithdrawState( - total_withdraw=withdraw_request.total_withdraw, - nonce=withdraw_request.nonce, - expiration=withdraw_request.expiration, - recipient_metadata=withdraw_request.sender_metadata, + total_withdraw=action.total_withdraw, + nonce=action.nonce, + expiration=action.expiration, + recipient_metadata=action.sender_metadata, ) channel_state.partner_state.withdraws_pending[ withdraw_state.total_withdraw ] = withdraw_state - channel_state.partner_state.nonce = withdraw_request.nonce + channel_state.partner_state.nonce = action.nonce channel_state.our_state.nonce = get_next_nonce(channel_state.our_state) send_withdraw = SendWithdrawConfirmation( canonical_identifier=channel_state.canonical_identifier, recipient=channel_state.partner_state.address, - recipient_metadata=withdraw_request.sender_metadata, - message_identifier=withdraw_request.message_identifier, - total_withdraw=withdraw_request.total_withdraw, + recipient_metadata=action.sender_metadata, + message_identifier=action.message_identifier, + total_withdraw=action.total_withdraw, participant=channel_state.partner_state.address, nonce=channel_state.our_state.nonce, expiration=withdraw_state.expiration, @@ -1959,48 +1977,50 @@ def handle_receive_withdraw_request( error_msg = is_valid.as_error_message assert error_msg, "is_valid_withdraw_request should return error msg if not valid" invalid_withdraw_request = EventInvalidReceivedWithdrawRequest( - attempted_withdraw=withdraw_request.total_withdraw, reason=error_msg + attempted_withdraw=action.total_withdraw, reason=error_msg ) events = [invalid_withdraw_request] return TransitionResult(channel_state, events) -def handle_receive_withdraw_confirmation( +@handle_state_transitions.register +def _handle_receive_withdraw_confirmation( + action: ReceiveWithdrawConfirmation, channel_state: NettingChannelState, - withdraw: ReceiveWithdrawConfirmation, block_number: BlockNumber, block_hash: BlockHash, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: is_valid = is_valid_withdraw_confirmation( - channel_state=channel_state, received_withdraw=withdraw + channel_state=channel_state, received_withdraw=action ) - withdraw_state = channel_state.our_state.withdraws_pending.get(withdraw.total_withdraw) + withdraw_state = channel_state.our_state.withdraws_pending.get(action.total_withdraw) recipient_metadata = None if withdraw_state is not None: recipient_metadata = withdraw_state.recipient_metadata events: List[Event] if is_valid: - channel_state.partner_state.nonce = withdraw.nonce + channel_state.partner_state.nonce = action.nonce events = [ SendProcessed( recipient=channel_state.partner_state.address, recipient_metadata=recipient_metadata, - message_identifier=withdraw.message_identifier, + message_identifier=action.message_identifier, canonical_identifier=CANONICAL_IDENTIFIER_UNORDERED_QUEUE, ) ] # Only send the transaction on-chain if there is enough time for the # withdraw transaction to be mined - if withdraw.expiration >= block_number - channel_state.reveal_timeout: + if action.expiration >= block_number - channel_state.reveal_timeout: withdraw_on_chain = ContractSendChannelWithdraw( - canonical_identifier=withdraw.canonical_identifier, - total_withdraw=withdraw.total_withdraw, - partner_signature=withdraw.signature, - expiration=withdraw.expiration, + canonical_identifier=action.canonical_identifier, + total_withdraw=action.total_withdraw, + partner_signature=action.signature, + expiration=action.expiration, triggered_by_block_hash=block_hash, ) events.append(withdraw_on_chain) @@ -2008,34 +2028,34 @@ def handle_receive_withdraw_confirmation( error_msg = is_valid.as_error_message assert error_msg, "is_valid_withdraw_confirmation should return error msg if not valid" invalid_withdraw = EventInvalidReceivedWithdraw( - attempted_withdraw=withdraw.total_withdraw, reason=error_msg + attempted_withdraw=action.total_withdraw, reason=error_msg ) events = [invalid_withdraw] return TransitionResult(channel_state, events) -def handle_receive_withdraw_expired( +@handle_state_transitions.register +def _handle_receive_withdraw_expired( + action: ReceiveWithdrawExpired, channel_state: NettingChannelState, - withdraw_expired: ReceiveWithdrawExpired, block_number: BlockNumber, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult: events: List[Event] = list() - withdraw_state = channel_state.partner_state.withdraws_pending.get( - withdraw_expired.total_withdraw - ) + withdraw_state = channel_state.partner_state.withdraws_pending.get(action.total_withdraw) if not withdraw_state: invalid_withdraw_expired_msg = ( - f"Withdraw expired of {withdraw_expired.total_withdraw} " + f"Withdraw expired of {action.total_withdraw} " f"did not correspond to previous withdraw request" ) return TransitionResult( channel_state, [ EventInvalidReceivedWithdrawExpired( - attempted_withdraw=withdraw_expired.total_withdraw, + attempted_withdraw=action.total_withdraw, reason=invalid_withdraw_expired_msg, ) ], @@ -2043,19 +2063,19 @@ def handle_receive_withdraw_expired( is_valid = is_valid_withdraw_expired( channel_state=channel_state, - state_change=withdraw_expired, + state_change=action, withdraw_state=withdraw_state, block_number=block_number, ) if is_valid: del channel_state.partner_state.withdraws_pending[withdraw_state.total_withdraw] - channel_state.partner_state.nonce = withdraw_expired.nonce + channel_state.partner_state.nonce = action.nonce send_processed = SendProcessed( recipient=channel_state.partner_state.address, recipient_metadata=withdraw_state.recipient_metadata, - message_identifier=withdraw_expired.message_identifier, + message_identifier=action.message_identifier, canonical_identifier=CANONICAL_IDENTIFIER_UNORDERED_QUEUE, ) events = [send_processed] @@ -2063,7 +2083,7 @@ def handle_receive_withdraw_expired( error_msg = is_valid.as_error_message assert error_msg, "is_valid_withdraw_expired should return error msg if not valid" invalid_withdraw_expired = EventInvalidReceivedWithdrawExpired( - attempted_withdraw=withdraw_expired.total_withdraw, reason=error_msg + attempted_withdraw=action.total_withdraw, reason=error_msg ) events = [invalid_withdraw_expired] @@ -2231,13 +2251,15 @@ def handle_unlock( return is_valid, events, msg -def handle_block( +@handle_state_transitions.register +def _handle_block( + action: Block, channel_state: NettingChannelState, - state_change: Block, block_number: BlockNumber, pseudo_random_generator: random.Random, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: - assert state_change.block_number == block_number, "Block number mismatch" + assert action.block_number == block_number, "Block number mismatch" events: List[Event] = list() @@ -2258,50 +2280,53 @@ def handle_block( closed_block_number = channel_state.close_transaction.finished_block_number settlement_end = closed_block_number + channel_state.settle_timeout - if state_change.block_number > settlement_end: + if action.block_number > settlement_end: channel_state.settle_transaction = TransactionExecutionStatus( - state_change.block_number, None, None + action.block_number, None, None ) event = ContractSendChannelSettle( canonical_identifier=channel_state.canonical_identifier, - triggered_by_block_hash=state_change.block_hash, + triggered_by_block_hash=action.block_hash, ) events.append(event) return TransitionResult(channel_state, events) -def handle_channel_closed( - channel_state: NettingChannelState, state_change: ContractReceiveChannelClosed +@handle_state_transitions.register +def _handle_channel_closed( + action: ContractReceiveChannelClosed, + channel_state: NettingChannelState, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: events: List[Event] = list() just_closed = ( - state_change.channel_identifier == channel_state.identifier + action.channel_identifier == channel_state.identifier and get_status(channel_state) in CHANNEL_STATES_PRIOR_TO_CLOSED ) if just_closed: - set_closed(channel_state, state_change.block_number) + set_closed(channel_state, action.block_number) balance_proof = channel_state.partner_state.balance_proof call_update = ( - state_change.transaction_from != channel_state.our_state.address + action.transaction_from != channel_state.our_state.address and balance_proof is not None and channel_state.update_transaction is None ) if call_update: - expiration = BlockExpiration(state_change.block_number + channel_state.settle_timeout) + expiration = BlockExpiration(action.block_number + channel_state.settle_timeout) assert isinstance(balance_proof, BalanceProofSignedState), MYPY_ANNOTATION # The channel was closed by our partner, if there is a balance # proof available update this node half of the state update = ContractSendChannelUpdateTransfer( expiration=expiration, balance_proof=balance_proof, - triggered_by_block_hash=state_change.block_hash, + triggered_by_block_hash=action.block_hash, ) channel_state.update_transaction = TransactionExecutionStatus( - started_block_number=state_change.block_number, + started_block_number=action.block_number, finished_block_number=None, result=None, ) @@ -2310,12 +2335,14 @@ def handle_channel_closed( return TransitionResult(channel_state, events) -def handle_channel_updated_transfer( +@handle_state_transitions.register +def _handle_channel_updated_transfer( + action: ContractReceiveUpdateTransfer, channel_state: NettingChannelState, - state_change: ContractReceiveUpdateTransfer, block_number: BlockNumber, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: - if state_change.channel_identifier == channel_state.identifier: + if action.channel_identifier == channel_state.identifier: # update transfer was called, make sure we don't call it again channel_state.update_transaction = TransactionExecutionStatus( started_block_number=None, @@ -2326,16 +2353,19 @@ def handle_channel_updated_transfer( return TransitionResult(channel_state, list()) -def handle_channel_settled( - channel_state: NettingChannelState, state_change: ContractReceiveChannelSettled +@handle_state_transitions.register +def _handle_channel_settled( + action: ContractReceiveChannelSettled, + channel_state: NettingChannelState, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[Optional[NettingChannelState]]: events: List[Event] = list() - if state_change.channel_identifier == channel_state.identifier: - set_settled(channel_state, state_change.block_number) + if action.channel_identifier == channel_state.identifier: + set_settled(channel_state, action.block_number) - our_locksroot = state_change.our_onchain_locksroot - partner_locksroot = state_change.partner_onchain_locksroot + our_locksroot = action.our_onchain_locksroot + partner_locksroot = action.partner_onchain_locksroot should_clear_channel = ( our_locksroot == LOCKSROOT_OF_NO_LOCKS and partner_locksroot == LOCKSROOT_OF_NO_LOCKS @@ -2350,7 +2380,7 @@ def handle_channel_settled( onchain_unlock = ContractSendChannelBatchUnlock( canonical_identifier=channel_state.canonical_identifier, sender=channel_state.partner_state.address, - triggered_by_block_hash=state_change.block_hash, + triggered_by_block_hash=action.block_hash, ) events.append(onchain_unlock) @@ -2377,11 +2407,14 @@ def update_fee_schedule_after_balance_change( return [] -def handle_channel_deposit( - channel_state: NettingChannelState, state_change: ContractReceiveChannelDeposit +@handle_state_transitions.register +def _handle_channel_deposit( + action: ContractReceiveChannelDeposit, + channel_state: NettingChannelState, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: - participant_address = state_change.deposit_transaction.participant_address - contract_balance = Balance(state_change.deposit_transaction.contract_balance) + participant_address = action.deposit_transaction.participant_address + contract_balance = Balance(action.deposit_transaction.contract_balance) if participant_address == channel_state.our_state.address: update_contract_balance(channel_state.our_state, contract_balance) @@ -2389,39 +2422,45 @@ def handle_channel_deposit( update_contract_balance(channel_state.partner_state, contract_balance) # A deposit changes the total capacity of the channel and as such the fees need to change - update_fee_schedule_after_balance_change(channel_state, state_change.fee_config) + update_fee_schedule_after_balance_change(channel_state, action.fee_config) return TransitionResult(channel_state, []) -def handle_channel_withdraw( - channel_state: NettingChannelState, state_change: ContractReceiveChannelWithdraw +@handle_state_transitions.register +def _handle_channel_withdraw( + action: ContractReceiveChannelWithdraw, + channel_state: NettingChannelState, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[NettingChannelState]: """An on-chain total withdraw took place which means that we have to keep track of this not to go lower than the on-chain value. The value is set to onchain_total_withdraw and the corresponding withdraw_state is cleared. """ participants = (channel_state.our_state.address, channel_state.partner_state.address) - if state_change.participant not in participants: + if action.participant not in participants: return TransitionResult(channel_state, list()) - if state_change.participant == channel_state.our_state.address: + if action.participant == channel_state.our_state.address: end_state = channel_state.our_state else: end_state = channel_state.partner_state - withdraw_state = end_state.withdraws_pending.get(state_change.total_withdraw) + withdraw_state = end_state.withdraws_pending.get(action.total_withdraw) if withdraw_state: - del end_state.withdraws_pending[state_change.total_withdraw] + del end_state.withdraws_pending[action.total_withdraw] - end_state.onchain_total_withdraw = state_change.total_withdraw + end_state.onchain_total_withdraw = action.total_withdraw # A withdraw changes the total capacity of the channel and as such the fees need to change - update_fee_schedule_after_balance_change(channel_state, state_change.fee_config) + update_fee_schedule_after_balance_change(channel_state, action.fee_config) return TransitionResult(channel_state, []) -def handle_channel_batch_unlock( - channel_state: NettingChannelState, state_change: ContractReceiveChannelBatchUnlock +@handle_state_transitions.register +def _handle_channel_batch_unlock( + action: ContractReceiveChannelBatchUnlock, + channel_state: NettingChannelState, + **kwargs: Optional[Dict[Any, Any]], ) -> TransitionResult[Optional[NettingChannelState]]: events: List[Event] = list() @@ -2434,9 +2473,9 @@ def handle_channel_batch_unlock( partner_state = channel_state.partner_state # partner is the address of the sender - if state_change.sender == our_state.address: + if action.sender == our_state.address: our_state.onchain_locksroot = Locksroot(LOCKSROOT_OF_NO_LOCKS) - elif state_change.sender == partner_state.address: + elif action.sender == partner_state.address: partner_state.onchain_locksroot = Locksroot(LOCKSROOT_OF_NO_LOCKS) # only clear the channel state once all unlocks have been done @@ -2543,91 +2582,14 @@ def state_transition( ) -> TransitionResult[Optional[NettingChannelState]]: # pragma: no unittest # pylint: disable=too-many-branches,unidiomatic-typecheck - events: List[Event] = list() - iteration: TransitionResult[Optional[NettingChannelState]] = TransitionResult( - channel_state, events + iteration: TransitionResult[Optional[NettingChannelState]] = handle_state_transitions( + state_change, + channel_state=channel_state, + block_number=block_number, + block_hash=block_hash, + pseudo_random_generator=pseudo_random_generator, ) - transition_map: Dict[Any, FuncMap] = { - Block: FuncMap( - handle_block, (channel_state, state_change, block_number, pseudo_random_generator), {} - ), - ActionChannelClose: FuncMap( - handle_action_close, - (), - dict( - channel_state=channel_state, - close=state_change, - block_number=block_number, - block_hash=block_hash, - ), - ), - ActionChannelWithdraw: FuncMap( - handle_action_withdraw, - (), - dict( - channel_state=channel_state, - action_withdraw=state_change, - pseudo_random_generator=pseudo_random_generator, - block_number=block_number, - ), - ), - ActionChannelSetRevealTimeout: FuncMap( - handle_action_set_reveal_timeout, - (), - dict(channel_state=channel_state, state_change=state_change), - ), - ContractReceiveChannelClosed: FuncMap( - handle_channel_closed, (channel_state, state_change), {} - ), - ContractReceiveUpdateTransfer: FuncMap( - handle_channel_updated_transfer, (channel_state, state_change, block_number), {} - ), - ContractReceiveChannelSettled: FuncMap( - handle_channel_settled, (channel_state, state_change), {} - ), - ContractReceiveChannelDeposit: FuncMap( - handle_channel_deposit, (channel_state, state_change), {} - ), - ContractReceiveChannelBatchUnlock: FuncMap( - handle_channel_batch_unlock, (channel_state, state_change), {} - ), - ContractReceiveChannelWithdraw: FuncMap( - handle_channel_withdraw, - (), - dict(channel_state=channel_state, state_change=state_change), - ), - ReceiveWithdrawRequest: FuncMap( - handle_receive_withdraw_request, - (), - dict(channel_state=channel_state, withdraw_request=state_change), - ), - ReceiveWithdrawConfirmation: FuncMap( - handle_receive_withdraw_confirmation, - (), - dict( - channel_state=channel_state, - withdraw=state_change, - block_number=block_number, - block_hash=block_hash, - ), - ), - ReceiveWithdrawExpired: FuncMap( - handle_receive_withdraw_expired, - (), - dict( - channel_state=channel_state, - withdraw_expired=state_change, - block_number=block_number, - ), - ), - } - - t_state_change = type(state_change) - func_map = transition_map.get(t_state_change) - if func_map: - iteration = func_map.function(*func_map.args, **func_map.kwargs) # type: ignore - if iteration.new_state is not None: sanity_check(iteration.new_state) diff --git a/raiden/transfer/mediated_transfer/state_change.py b/raiden/transfer/mediated_transfer/state_change.py index 96c6280466..31e5fe896c 100644 --- a/raiden/transfer/mediated_transfer/state_change.py +++ b/raiden/transfer/mediated_transfer/state_change.py @@ -68,6 +68,7 @@ class ActionInitTarget(BalanceProofStateChange): from_hop: HopState transfer: LockedTransferSignedState + received_valid_secret: Optional[bool] = field(default=False) def __post_init__(self) -> None: super().__post_init__() diff --git a/raiden/transfer/mediated_transfer/target.py b/raiden/transfer/mediated_transfer/target.py index 206b7a1beb..096fc6584f 100644 --- a/raiden/transfer/mediated_transfer/target.py +++ b/raiden/transfer/mediated_transfer/target.py @@ -112,6 +112,8 @@ def handle_inittarget( # enforced by the nonce increasing sequentially, which is verified by # the handler handle_receive_lockedtransfer. target_state = TargetTransferState(from_hop, transfer) + if state_change.received_valid_secret: + return TransitionResult(target_state, channel_events) safe_to_wait = is_safe_to_wait( transfer.lock.expiration, channel_state.reveal_timeout, block_number diff --git a/raiden/transfer/utils.py b/raiden/transfer/utils.py index 82b209f7be..c859461182 100644 --- a/raiden/transfer/utils.py +++ b/raiden/transfer/utils.py @@ -83,7 +83,7 @@ def encrypt_secret( def decrypt_secret(encrypted_secret: EncryptedSecret, private_key: PrivateKey) -> Secret: try: - secret = Secret(decrypt(encrypted_secret, private_key)) + secret = Secret(decrypt(private_key, encrypted_secret)) except ValueError: raise InvalidSecret return secret