# Superswaps

In [None]:
#| default_exp superswap

In [None]:
#| export

import json, requests
from sugar.swap import build_super_swap_data, SuperSwapData, SuperSwapQuote, setup_planner, SuperSwapDataInput
from sugar.token import Token
from sugar.helpers import get_salt, serialize_ica_calls
from sugar.config import hyperlane_relay_url, hyperlane_relayers
from sugar.chains import get_async_chain_from_token, AsyncChain, AsyncOPChain
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

# TODO: remove this when domains are supported on all chains
async def get_domain(chain_id: int) -> int:
    """Look up the messaging domain registered for `chain_id`.

    Calls the `domains(uint256)` view on the message module contract configured
    for the OP chain (`message_module_contract_addr`). A result of 0 is treated
    as "no domain registered", and the chain id itself is returned instead.
    """
    # TODO: remove chain_id arg when all chains support domains
    chain_id_param = {"name": "", "type": "uint256"}
    domain_result = {"name": "domain", "type": "uint256"}
    domains_entry = dict(
        name="domains",
        type="function",
        stateMutability="view",
        inputs=[chain_id_param],
        outputs=[domain_result],
    )
    async with AsyncOPChain() as op_chain:
        message_module = op_chain.web3.eth.contract(
            address=op_chain.settings.message_module_contract_addr,
            abi=[domains_entry],
        )
        registered = await message_module.functions.domains(chain_id).call()
        # TODO: remove fallback to chain_id when all chains support domains
        if registered == 0:
            return int(chain_id)
        return registered

In [None]:
#| export

class SuperswapRelayer(ABC):
    """Abstract interface for sharing superswap calls with a relayer.

    Concrete implementations decide how the calls are delivered.
    """

    @abstractmethod
    def share_calls(self, calls: List[dict], salt: str, commitment_dispatch_tx: str, origin_domain: int) -> None:
        """Share calls with the relayer."""


class HTTPSuperswapRelayer(SuperswapRelayer):
    """HTTP-based relayer implementation.

    POSTs the calls as JSON to `hyperlane_relay_url`, together with the
    configured `hyperlane_relayers`.
    """

    # Seconds to wait for the relay endpoint (connect and read). Without a
    # timeout `requests` can block indefinitely on an unresponsive server.
    # Override on a subclass or instance if the relayer needs longer.
    timeout: float = 30.0

    def share_calls(self, calls: List[dict], salt: str, commitment_dispatch_tx: str, origin_domain: int) -> None:
        """
        Share calls with private relayer.

        Args:
            calls: List of call data dictionaries
            salt: Hex string salt value
            commitment_dispatch_tx: Transaction hash string
            origin_domain: Domain number

        Raises:
            Exception: if the relayer answers with an HTTP error status (>= 400).
            requests.exceptions.RequestException: on connection errors or when
                `timeout` elapses.
        """
        body = json.dumps({
            'commitmentDispatchTx': commitment_dispatch_tx,
            'originDomain': origin_domain,
            'calls': calls,
            'salt': salt,
            'relayers': hyperlane_relayers
        })
        resp = requests.post(
            hyperlane_relay_url,
            headers={'Content-Type': 'application/json'},
            data=body,
            timeout=self.timeout,
        )
        print(f"Hyperlane response: {resp.status_code}: {resp.text}")
        if not resp.ok:
            error_msg = f"Failed to share calls with relayer: {resp.status_code} {resp.text}"
            print(f"Error: {error_msg}")
            raise Exception(error_msg)

In [None]:
#| export

# TODO: add helper to inspect tx using https://explorer.hyperlane.xyz/?search

# Chain names (as reported by a chain's `.name`) that superswap may route between.
# Extend as more chains gain support. Keep this a list: its repr is embedded
# verbatim in the "unsupported chain" error message.
supported_chains = [
    "OP",
    "Lisk",
    "Uni",
]

class AsyncSuperswap:
    """Quote and execute cross-chain ("super") swaps between `supported_chains`.

    A superswap has up to three legs:
    1. An optional origin swap from `from_token` into the origin chain's connector (bridge) token.
    2. The bridge transfer itself.
    3. An optional destination swap from the destination connector token into `to_token`.

    The destination swap is encoded by `build_super_swap_data`, which also
    receives the user's interchain account (ICA).
    """

    # Slippage used by `swap` when the caller does not pass one.
    DEFAULT_SLIPPAGE = 0.1

    def __init__(self, relayer: Optional[SuperswapRelayer] = None, chain_for_writes: Optional[AsyncChain] = None):
        """
        Args:
            relayer: Receives the ICA calls when a swap needs relaying.
                Defaults to `HTTPSuperswapRelayer()`.
            chain_for_writes: Chain used to sign and send the origin transaction.
                When None, the chain of the quote's `from_token` is used.
        """
        self.chain_for_writes = chain_for_writes
        self.relayer = relayer or HTTPSuperswapRelayer()

    def check_chain_support(self, from_token: Token, to_token: Token) -> None:
        """Check if the given tokens are supported for superswap.

        Raises:
            ValueError: if either token's chain is not in `supported_chains`.
        """
        from_chain, to_chain = get_async_chain_from_token(from_token), get_async_chain_from_token(to_token)
        if from_chain.name not in supported_chains or to_chain.name not in supported_chains:
            raise ValueError(f"Superswap only supports {supported_chains}. Got {from_chain.name} -> {to_chain.name}")

    async def swap(self, from_token: Token, to_token: Token, amount: float, slippage: Optional[float] = None) -> str:
        """Quote and execute a superswap of `amount` `from_token` into `to_token`.

        Args:
            from_token: Token sold on the origin chain.
            to_token: Token received on the destination chain.
            amount: Amount of `from_token` in human-readable units (not wei).
            slippage: Slippage tolerance. `DEFAULT_SLIPPAGE` is used when None.

        Returns:
            The origin transaction hash as a 0x-prefixed hex string.

        Raises:
            ValueError: if either chain is unsupported or no account is connected.
        """
        self.check_chain_support(from_token, to_token)
        quote = await self.get_super_quote(from_token=from_token, to_token=to_token, amount_in=amount)
        effective_slippage = self.DEFAULT_SLIPPAGE if slippage is None else slippage
        return await self.swap_from_quote(quote=quote, slippage=effective_slippage)

    async def get_super_quote(self, from_token: Token, to_token: Token, amount_in: float) -> SuperSwapQuote:
        """Build a `SuperSwapQuote` for `amount_in` (human units) of `from_token` into `to_token`.

        The origin leg (`from_token` -> origin connector token) is only quoted
        when `from_token` is not already the connector token. Likewise, the
        destination leg (destination connector token -> `to_token`) is only
        quoted when `to_token` is not the destination connector token. Legs that
        are skipped stay None on the returned quote.

        Raises:
            AssertionError: if an origin leg is needed but no quote is found.
        """
        from_chain, to_chain = get_async_chain_from_token(from_token), get_async_chain_from_token(to_token)
        async with from_chain, to_chain:
            origin_bridge_token = await from_chain.get_superswap_connector_token()
            destination_bridge_token = await to_chain.get_superswap_connector_token()

        starts_with_bridge_token = from_token.token_address == origin_bridge_token.token_address
        ends_with_bridge_token = to_token.token_address == destination_bridge_token.token_address
        origin_quote, destination_quote = None, None

        # origin swap is only needed when we don't already start with the connector token
        if not starts_with_bridge_token:
            async with from_chain:
                origin_quote = await from_chain.get_quote(from_token, origin_bridge_token, amount_in)
                assert origin_quote is not None, "No origin quote found"

        # destination swap is only needed when we don't end with the connector token
        if not ends_with_bridge_token:
            # get_quote expects amount_in in "normal" (not wei) units, so convert the
            # origin quote's wei output back before quoting the destination leg
            if starts_with_bridge_token:
                bridged_amount = amount_in
            else:
                bridged_amount = float(origin_quote.amount_out / 10 ** origin_bridge_token.decimals)
            async with to_chain:
                destination_quote = await to_chain.get_quote(destination_bridge_token, to_token, bridged_amount)

        return SuperSwapQuote(
            from_token=from_token,
            to_token=to_token,
            from_bridge_token=origin_bridge_token,
            to_bridge_token=destination_bridge_token,
            amount_in=amount_in,
            origin_quote=origin_quote,
            destination_quote=destination_quote
        )

    async def swap_from_quote(self, quote: SuperSwapQuote, slippage: float, salt: Optional[str] = None):
        """Execute the superswap described by `quote`.

        Args:
            quote: Quote produced by `get_super_quote`.
            slippage: Slippage tolerance for the origin planner and the destination swap data.
            salt: Salt for the ICA calls. A fresh `get_salt()` value is used when omitted.

        Returns:
            The origin transaction hash as a 0x-prefixed hex string.

        Raises:
            ValueError: if either chain is unsupported or the origin chain has no account.
        """
        # validate the quote's own tokens (the previous code referenced undefined names here)
        self.check_chain_support(quote.from_token, quote.to_token)
        from_chain, to_chain = get_async_chain_from_token(quote.from_token), get_async_chain_from_token(quote.to_token)

        origin_quote, destination_quote = quote.origin_quote, quote.destination_quote
        # Without an origin swap (we start with the connector token) there is no
        # origin quote. The full input amount, converted to wei, is bridged as-is.
        if origin_quote is not None:
            bridged_amount = origin_quote.amount_out
        else:
            bridged_amount = int(quote.amount_in * 10 ** quote.from_bridge_token.decimals)

        async with from_chain, to_chain:
            if not from_chain.account:
                raise ValueError("Cannot superswap without an account. Please connect your wallet first.")
            # TODO: use chain.get_domain() when all chains support domains
            origin_domain = await get_domain(int(from_chain.id))
            destination_domain = await get_domain(int(to_chain.id))
            user_ica_address = await from_chain.get_remote_interchain_account(destination_domain)
            bridge_fee = await from_chain.get_bridge_fee(int(to_chain.id))
            xchain_fee = await from_chain.get_xchain_fee(destination_domain)
            # the cross-chain message fee is only paid when a destination swap follows the bridge
            ends_with_bridge_token = quote.to_token.token_address == quote.to_bridge_token.token_address
            total_fee = bridge_fee if ends_with_bridge_token else bridge_fee + xchain_fee

            swap_data = build_super_swap_data(SuperSwapDataInput(
                from_token=quote.from_token,
                to_token=quote.to_token,
                from_bridge_token=quote.from_bridge_token,
                to_bridge_token=quote.to_bridge_token,
                account=from_chain.account.address,
                user_ICA=user_ica_address,
                user_ICA_balance=await to_chain.get_user_ica_balance(user_ica_address),
                origin_domain=origin_domain,
                origin_bridge=from_chain.settings.bridge_contract_addr,
                origin_hook=await from_chain.get_ica_hook(),
                origin_ICA_router=from_chain.settings.interchain_router_contract_addr,
                destination_ICA_router=to_chain.settings.interchain_router_contract_addr,
                destination_router=to_chain.settings.swapper_contract_addr,
                destination_domain=destination_domain,
                slippage=slippage,
                bridged_amount=bridged_amount,
                swapper_contract_addr=to_chain.settings.swapper_contract_addr,
                destination_quote=destination_quote,
                bridge_fee=bridge_fee,
                xchain_fee=xchain_fee,
                salt=salt if salt else get_salt()
            ))

            origin_planner = setup_planner(
                quote=origin_quote,
                slippage=slippage,
                # money goes to the universal router (aka swapper) for bridging
                account=from_chain.settings.swapper_contract_addr,
                router_address=from_chain.settings.swapper_contract_addr
            ) if origin_quote else None

            cmds, inputs = "", []
            if origin_planner:
                cmds += origin_planner.commands
                inputs.extend(origin_planner.inputs)
            if swap_data.destination_planner:
                destination_cmds = swap_data.destination_planner.commands
                # commands appear to be 0x-prefixed hex; strip the second prefix when concatenating
                cmds += destination_cmds.replace("0x", "") if cmds else destination_cmds
                inputs.extend(swap_data.destination_planner.inputs)

            return await self.write(chain=self.chain_for_writes or from_chain, cmds=cmds, inputs=inputs, quote=quote, swap_data=swap_data, total_fee=total_fee)

    async def write(self, chain: AsyncChain, quote: SuperSwapQuote, swap_data: SuperSwapData, cmds: str, inputs: List[bytes], total_fee: int) -> str:
        """Approve the swapper, send `execute(cmds, inputs)` on `chain`, and relay ICA calls if needed.

        Returns:
            The transaction hash as a 0x-prefixed hex string.
        """
        swapper_contract_addr, from_token = chain.settings.swapper_contract_addr, quote.from_token
        # amount_in is in human units and may be a float; allowance and tx value must be integer wei
        value = int(quote.amount_in * (10 ** from_token.decimals))
        # TODO: extend this to proper native token support
        message_fee = value + total_fee if quote.from_token.wrapped_token_address else total_fee
        async with chain:
            await chain.set_token_allowance(from_token, swapper_contract_addr, value)
            tx = await chain.sign_and_send_tx(chain.swapper.functions.execute(cmds, inputs), value=message_fee)
            tx_hash = f'0x{tx["transactionHash"].hex()}'
            if swap_data.needs_relay:
                # NOTE(review): share_calls is synchronous; an HTTP relayer blocks the event loop here
                self.relayer.share_calls(
                    calls=serialize_ica_calls(swap_data.calls),
                    salt=swap_data.salt,
                    commitment_dispatch_tx=tx_hash,
                    origin_domain=swap_data.origin_domain
                )
            return tx_hash
        

## Tests

In [None]:
from sugar.helpers import require_supersim
from fastcore.test import test_eq

# The tests below send transactions through AsyncOPChainSimnet and need supersim.
# require_supersim() presumably aborts or skips when it is not running (confirm in sugar.helpers).
require_supersim()

In [None]:
class MockSuperswapRelayer(SuperswapRelayer):
    """Mock relayer implementation for testing.

    Instead of contacting a relayer, every `share_calls` invocation is appended
    to `calls_history` (and printed) so tests can inspect it afterwards.
    """

    def __init__(self):
        # one record per share_calls invocation, oldest first
        self.calls_history: List[Dict[str, Any]] = []

    def share_calls(self, calls: List[dict], salt: str, commitment_dispatch_tx: str, origin_domain: int) -> None:
        """Record the arguments of this call for later verification."""
        record = dict(
            calls=calls,
            salt=salt,
            commitment_dispatch_tx=commitment_dispatch_tx,
            origin_domain=origin_domain,
        )
        self.calls_history.append(record)
        print(f"Mock relayer received call: {record}")

    def get_last_call(self) -> Optional[Dict[str, Any]]:
        """Return the most recently recorded call, or None if nothing was recorded."""
        if not self.calls_history:
            return None
        return self.calls_history[-1]

    def get_call_count(self) -> int:
        """Return how many calls have been recorded so far."""
        return len(self.calls_history)
    


In [None]:
from sugar import AsyncOPChainSimnet, AsyncOPChain, AsyncBaseChain, AsyncLiskChain

# try unsupported chains
# OP -> Base must be rejected by check_chain_support before any quoting happens.
async with AsyncOPChainSimnet() as op_sim:
    from_token, to_token = AsyncOPChain.velo, AsyncBaseChain.aero 
    error = None
    try:
        tx = await AsyncSuperswap(chain_for_writes=op_sim, relayer=MockSuperswapRelayer()).swap(from_token, to_token, amount=20)
    except ValueError as e:
        error = e
    # If no ValueError was raised, error stays None and this comparison fails.
    test_eq(str(error), "Superswap only supports ['OP', 'Lisk', 'Uni']. Got OP -> Base")

# Supported route: OP (VELO) -> Lisk (LSK), written through the OP simnet with a mock relayer.
async with AsyncOPChainSimnet() as op_sim:
    relayer=MockSuperswapRelayer()
    from_token, to_token = AsyncOPChain.velo, AsyncLiskChain.lsk
    tx = await AsyncSuperswap(chain_for_writes=op_sim, relayer=relayer).swap(from_token, to_token, amount=20)
    assert(tx.startswith("0x"))
    # Exactly one relay is expected: write() only calls the relayer when swap_data.needs_relay is set.
    test_eq(relayer.get_call_count(), 1)
    last_call = relayer.get_last_call()
    test_eq(type(last_call["salt"]), str)
    test_eq(type(last_call["origin_domain"]), int)

Mock relayer received call: {'calls': [{'to': '0x0000000000000000000000001217bfe6c773eec6cc4a38b5dc45b92292b6e189', 'value': '0', 'data': '0x095ea7b300000000000000000000000001d40099fcd87c018969b0e8d4ab1633fb34763cffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff'}, {'to': '0x00000000000000000000000001d40099fcd87c018969b0e8d4ab1633fb34763c', 'value': '0', 'data': '0x24856bc3000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000002a1a100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000046000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000

In [None]:
#| hide
# Export this notebook's `#| export` cells to the library module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()