# Demo the Arbitrageur Agent Logic

This notebook:

1. Fetches the latest crvUSD contract data into Python objects using `crvusdsim`.
2. Fetches prices and `ExternalMarket`s from the `baseline` scenario.
3. Instantiates an Arbitrageur agent.
4. Checks for profitable arbitrages and executes them.

In [1]:
from crvusdsim.pool import get
from src.modules import ExternalMarket
from src.agents.arbitrageur import Arbitrageur
from src.sim.scenario import Scenario

%load_ext autoreload
%autoreload 2

In [2]:
"""
The way this should work:
1. Get all StableSwap pools and relevant external markets.
2. Create a list of all possible cycles of the form:
    [Flashswap] -> StableSwap Pool -> StableSwap Pool -> External Market
    TODO how to handle flashswap fee/slippage?
3. For each cycle, find the amount into the first trade that maximizes profit:
    create a cycle object, then run `minimize_scalar` on the negative of the
    profit from its `execute()` method (minimize_scalar minimizes, so the
    profit is negated).
4. Execute the most profitable cycle.
5. Repeat steps 3-4 until no profitable cycles remain.

TODO maybe there should be a class that holds all the pools, so a single
variable can be passed to any agent or function. This could just be a dict
keyed by pool type:

pools: dict[str, list] = {
    # <type>: [pools of that type]
    "SimCurveStableSwapPool": [],
    "SimLLAMMAPool": [],
    "SimCurvePool": [],
    "ExternalMarket": [],
}
"""
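Steps 3-5 of the plan above can be sketched on toy pools. This is a minimal, dependency-free illustration under stated assumptions, not the `Arbitrageur` implementation: `ToyPool` is a hypothetical constant-product (`x * y = k`) pool that charges its fee on the input and uses floats for readability, a cycle is a hypothetical list of `(pool, i, j)` hops, and a golden-section search stands in for `scipy.optimize.minimize_scalar`. With SciPy, something like `minimize_scalar(lambda x: -profit(x), bounds=(0, high), method="bounded")` would play the same role. Golden-section search assumes profit is unimodal in the input amount, which holds for constant-product pools but is only an assumption for other pool types.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class ToyPool:
    """Hypothetical 2-coin constant-product pool (x * y = k); fee charged on input."""

    reserves: List[float]
    fee: float = 0.003

    def quote(self, i: int, j: int, dx: float) -> float:
        """Amount of coin j received for dx of coin i, without changing state."""
        dx_net = dx * (1 - self.fee)
        x, y = self.reserves[i], self.reserves[j]
        return y * dx_net / (x + dx_net)

    def trade(self, i: int, j: int, dx: float) -> float:
        dy = self.quote(i, j, dx)
        self.reserves[i] += dx
        self.reserves[j] -= dy
        return dy


# Hypothetical cycle: (pool, i, j) hops; each hop's output is the next hop's input.
ToyCycle = List[Tuple[ToyPool, int, int]]


def cycle_profit(cycle: ToyCycle, amt_in: float) -> float:
    amt = amt_in
    for pool, i, j in cycle:
        amt = pool.quote(i, j, amt)
    return amt - amt_in


def golden_max(f: Callable[[float], float], lo: float, hi: float,
               tol: float = 1e-9, max_iter: int = 200) -> Tuple[float, float]:
    """Maximize a unimodal f on [lo, hi] by golden-section search."""
    r = (math.sqrt(5) - 1) / 2  # ~0.618
    a, b = lo, hi
    c, d = b - r * (b - a), a + r * (b - a)
    fc, fd = f(c), f(d)
    for _ in range(max_iter):  # cap iterations in case tol is below float precision
        if b - a <= tol:
            break
        if fc > fd:  # maximum lies in [a, d]
            b, d, fd = d, c, fc
            c = b - r * (b - a)
            fc = f(c)
        else:  # maximum lies in [c, b]
            a, c, fc = c, d, fd
            d = a + r * (b - a)
            fd = f(d)
    x = (a + b) / 2
    return x, f(x)


def best_amount(cycle: ToyCycle, high: float) -> Tuple[float, float]:
    """Step 3: input amount in [0, high] that maximizes the cycle's profit."""
    return golden_max(lambda x: cycle_profit(cycle, x), 0.0, high)


def run_arbitrage(cycles: List[ToyCycle], high: float,
                  min_profit: float = 1e-9, max_rounds: int = 100) -> float:
    """Steps 3-5: execute the most profitable cycle until none is profitable."""
    total = 0.0
    for _ in range(max_rounds):
        (amt_in, profit), best = max(
            ((best_amount(c, high), c) for c in cycles), key=lambda r: r[0][1]
        )
        if profit <= min_profit:
            break
        amt = amt_in
        for pool, i, j in best:
            amt = pool.trade(i, j, amt)
        total += amt - amt_in
    return total


a = ToyPool([1000.0, 1100.0])  # coin 1 is cheap here relative to pool b
b = ToyPool([1000.0, 1000.0])
forward = [(a, 0, 1), (b, 1, 0)]
print(run_arbitrage([forward, [(b, 0, 1), (a, 1, 0)]], high=500.0))  # positive total profit
```

Because each toy pool appears at most once per cycle, `cycle_profit` can use non-mutating quotes; only the winning cycle is executed with `trade`, which pushes its marginal rate toward 1 so the loop typically stops after one round.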

In [3]:
arbitrageur = Arbitrageur()

ETH_POOL = "weth"
(
    pool,
    controller,
    collateral_token,
    stablecoin,
    aggregator,
    stableswap_pools,
    peg_keepers,
    policy,
    factory,
) = get(ETH_POOL, bands_data="controller")

In [4]:
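# Verify that the snapshot context on crvusdsim StableSwap pools reverts trades:
# trade inside `use_snapshot_context()`, then check the balances are restored on exit.
# NB: "Allowance" below prints coin 0's ERC20 balance held by the pool. After the
# trade it exceeds `balances[0]`, presumably because Curve pools exclude accrued
# admin fees from `balances`.
i, j = 1, 0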

stableswap_pool = stableswap_pools[2]
high = stableswap_pool.get_max_trade_size(i, j)
prev_bals = stableswap_pool.balances.copy()

print("Pre Trade")
print("Balances", prev_bals)
print("High", high)
print("Allowance", stableswap_pool.coins[0].balanceOf[stableswap_pool.address])
print()

with stableswap_pool.use_snapshot_context():
    amt_out = stableswap_pool.trade(i, j, high)
    swap_bals = stableswap_pool.balances.copy()
    print("Post Trade")
    print("Amount out", amt_out)
    print("Balances", swap_bals)
    print("Allowance", stableswap_pool.coins[0].balanceOf[stableswap_pool.address])
    print()

new_bals = stableswap_pool.balances.copy()
print("After reversing snapshot context")
print("Balances", new_bals)
print("Allowance", stableswap_pool.coins[0].balanceOf[stableswap_pool.address])

assert prev_bals == new_bals, (prev_bals, new_bals)
assert swap_bals != new_bals

Pre Trade
Balances [9800026537423, 26112365622012323163758861]
High 12698760943745326035464421
Allowance 9800026537423

Post Trade
Amount out 9701056069421
Balances [98485366689, 38811126565757649199223282]
Allowance 98970468002

After reversing snapshot context
Balances [9800026537423, 26112365622012323163758861]
Allowance 9800026537423
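The check above relies on `use_snapshot_context()` rolling back every state change made inside the `with` block. One way such a context manager could work is sketched below; this is an assumption for illustration, not crvusdsim's actual implementation. `SnapshotMixin` and `ToySnapshotPool` are hypothetical: the mixin deep-copies the object's attributes on entry and restores them on exit, even if the body raises.

```python
import copy
from contextlib import contextmanager
from typing import Iterator, List


class SnapshotMixin:
    """Hypothetical mixin: state changes made inside the context are rolled back."""

    @contextmanager
    def use_snapshot_context(self) -> Iterator[None]:
        snapshot = copy.deepcopy(self.__dict__)  # save every attribute
        try:
            yield
        finally:
            # Restore the saved attributes, even if the body raised.
            self.__dict__.clear()
            self.__dict__.update(snapshot)


class ToySnapshotPool(SnapshotMixin):
    """Hypothetical pool that swaps 1:1, capped by the balance of coin j."""

    def __init__(self, balances: List[int]):
        self.balances = balances

    def trade(self, i: int, j: int, dx: int) -> int:
        dy = min(dx, self.balances[j])
        self.balances[i] += dx
        self.balances[j] -= dy
        return dy


toy = ToySnapshotPool([100, 100])
before = toy.balances.copy()
with toy.use_snapshot_context():
    toy.trade(1, 0, 40)
    assert toy.balances == [60, 140]
assert toy.balances == before  # the trade was rolled back
```

A plain `deepcopy` also duplicates objects the pool shares with others (such as the token contracts whose `balanceOf` the cell above prints), so a real implementation would likely need to snapshot those explicitly.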


In [5]:
config = "../src/configs/scenarios/baseline.json"
scenario = Scenario(config)
markets = scenario.generate_markets()
prices = scenario.generate_pricepaths(
    "../src/configs/prices/1h_1694894242_1700078242.json"
)

# Set each external market's price from the last sample of the price path
sample = prices[-1].prices
for in_token in markets:
    for out_token in markets[in_token]:
        markets[in_token][out_token].update_price(sample[in_token][out_token])

[INFO][02:42:36][root]-444465: Reading price config from ../src/configs/scenarios/baseline.json.
[INFO][02:42:36][root]-444465: Fetching 1inch quotes.


[INFO][02:42:41][root]-444465: We have 380160 quotes.
[INFO][02:42:41][root]-444465: Fitting external markets against 1inch quotes.
[INFO][02:42:41][root]-444465: Reading price config from ../src/configs/prices/1h_1694894242_1700078242.json.


In [6]:
# Create a list of all stableswap pools, LLAMMAs, and External Markets.
markets_lst = [item for subdict in markets.values() for item in subdict.values()]
pools = stableswap_pools + [pool] + markets_lst
print(
    f"There are {len(pools)} total pools:\nStableSwap: {len(stableswap_pools)}\nLLAMMA: 1\nExternal Markets: {len(markets_lst)}"
)

There are 35 total pools:
StableSwap: 4
LLAMMA: 1
External Markets: 30


In [7]:
import logging
from typing import List, Any, Optional, Tuple
from src.types import Swap, Cycle

# TODO add proper pool typing instead of Any


def shared_address(p1: Any, p2: Any, used: Optional[set] = None) -> set:
    """Return the coin address(es) shared by two pools, excluding any in `used`."""
    if p1 == p2:
        raise ValueError("Cannot share coins with self.")

    if isinstance(p1, ExternalMarket) and isinstance(p2, ExternalMarket):
        # External markets are "directional", so they'll share both coins.
        # We never trade between two external markets, so report no shared coin
        # (an empty set, so callers' truthiness checks stay correct).
        # TODO refactor them to make them undirected
        return set()

    c1 = [c.lower() for c in p1.coin_addresses]
    c2 = [c.lower() for c in p2.coin_addresses]

    shared = (set(c1) & set(c2)) - (used or set())
    if len(shared) > 1:
        raise NotImplementedError(
            f"We assume at most one shared coin. {type(p1), type(p2)}"
        )
    return shared


def get_shared_idxs(p1: Any, p2: Any) -> Tuple[int, int]:
    """Get the index of the shared coin in each pool."""
    shared = shared_address(p1, p2).pop()

    # FIXME inefficient to recreate c1 and c2
    c1 = [c.lower() for c in p1.coin_addresses]
    c2 = [c.lower() for c in p2.coin_addresses]

    return c1.index(shared), c2.index(shared)


class PoolGraph:
    # TODO move this to its own file

    def __init__(self, pools: List[Any]):
        self.pools = pools
        self.graph = self.create_graph()

    def create_graph(self) -> dict[Any, List[Any]]:
        graph = {}
        for pool in self.pools:
            assert len(pool.coin_addresses) == 2, NotImplementedError(
                "Only 2-coin pools"
            )
            graph[pool] = []
            for other in self.pools:
                if other != pool and bool(shared_address(pool, other)):
                    graph[pool].append(other)
        return graph

    def find_cycles(self, n: int = 3) -> List[Cycle]:
        # TODO currently assumes only one shared coin between
        # any two pools.
        if len(self.pools) < n:
            raise ValueError("Not enough pools to form a cycle.")
        cycles = []
        for pool in self.pools:
            self.dfs(pool, [pool], set(), cycles, n)
        valid = self.validate(cycles)
        logging.info(f"Found {len(valid)} valid cycles of length {n}.")
        return valid

    def can_traverse(self, curr: Any, nxt: Any, used: set) -> bool:
        if isinstance(curr, ExternalMarket) and isinstance(nxt, ExternalMarket):
            # Don't traverse between external markets
            return False
        # Traverse only via a coin not already used earlier in the path.
        return bool(shared_address(curr, nxt, used))

    def update_used_coins(self, used: set, curr: Any, nxt: Any):
        used.update(shared_address(curr, nxt))

    def revert_used_coins(self, used: set, curr: Any, nxt: Any):
        used.difference_update(shared_address(curr, nxt))

    def construct_cycle(self, path: List[Any], n: int) -> Cycle:
        """Turn a closed path of pools into a Cycle of Swaps (amounts unset)."""
        trades = []
        for i, pool in enumerate(path):
            nxt = path[(i + 1) % n]
            # Buy the coin shared with the next pool (token out) and sell the
            # other coin; for 2-coin pools `idx ^ 1` is the other index.
            idx, _ = get_shared_idxs(pool, nxt)
            trades.append(Swap(pool, idx ^ 1, idx, None))
        return Cycle(trades)

    def dfs(
        self, curr: Any, path: List[Any], used: set, cycles: List[Cycle], n: int
    ):
        if len(path) == n:
            # Ensure cycle is closed
            shared = shared_address(path[0], path[-1], used)
            if bool(shared):
                cycles.append(self.construct_cycle(path, n))
            return

        for nxt in self.graph[curr]:
            if nxt in path:
                # Only visit each pool once per cycle
                continue
            if self.can_traverse(curr, nxt, used):
                path.append(nxt)
                self.update_used_coins(used, curr, nxt)
                self.dfs(nxt, path, used, cycles, n)
                path.pop()
                self.revert_used_coins(used, curr, nxt)

    def validate(self, cycles: List[Cycle]) -> List[Cycle]:
        """
        Keep only cycles that contain exactly one ExternalMarket,
        positioned at the end of the cycle.
        """
        valid = []
        for cycle in cycles:
            cycle_pools = [t.pool for t in cycle.trades]
            if not isinstance(cycle_pools[-1], ExternalMarket):
                continue
            if any(isinstance(p, ExternalMarket) for p in cycle_pools[:-1]):
                # An ExternalMarket appears before the last trade.
                continue
            valid.append(cycle)
        return valid
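A stripped-down, dependency-free version of the `PoolGraph` search might look like the sketch below. It keeps the same rules (each pool visited once, each linking coin used once, no hop between two external markets, a closing hop back to the first pool, exactly one external market and it comes last), but replaces the crvusdsim pools with a hypothetical `ToyPool` tuple of two coin symbols and checks neighbours on the fly instead of precomputing an adjacency list. As in the notebook, it assumes any two pools share at most one coin.

```python
from typing import List, NamedTuple, Set, Tuple


class ToyPool(NamedTuple):
    """Hypothetical stand-in for a 2-coin pool or an external market."""

    name: str
    coins: Tuple[str, str]
    external: bool = False


Hop = Tuple[str, str, str]  # (pool name, coin in, coin out)


def shared(p: ToyPool, q: ToyPool, used: Set[str]) -> Set[str]:
    """Coins held by both pools, excluding coins already used in the path."""
    return (set(p.coins) & set(q.coins)) - used


def find_cycles(pools: List[ToyPool], n: int = 3) -> List[List[Hop]]:
    found: List[List[ToyPool]] = []

    def dfs(path: List[ToyPool], used: Set[str]) -> None:
        if len(path) == n:
            if shared(path[-1], path[0], used):  # cycle closes on an unused coin
                found.append(path.copy())
            return
        for nxt in pools:
            if nxt in path or (path[-1].external and nxt.external):
                continue  # visit each pool once; never hop market -> market
            link = shared(path[-1], nxt, used)
            if link:
                path.append(nxt)
                dfs(path, used | link)
                path.pop()

    for start in pools:
        dfs([start], set())

    cycles = []
    for path in found:
        # Keep cycles with exactly one external market, placed last.
        if not path[-1].external or any(p.external for p in path[:-1]):
            continue
        hops = []
        for k, pool in enumerate(path):
            # Coin out is shared with the next pool; coin in is the other coin.
            out = (set(pool.coins) & set(path[(k + 1) % n].coins)).pop()
            hops.append((pool.name, pool.coins[1 - pool.coins.index(out)], out))
        cycles.append(hops)
    return cycles


usdc = ToyPool("USDC/crvUSD", ("USDC", "crvUSD"))
usdt = ToyPool("USDT/crvUSD", ("USDT", "crvUSD"))
ext = ToyPool("ext USDC/USDT", ("USDC", "USDT"), external=True)
for hops in find_cycles([usdc, usdt, ext]):
    print(hops)  # 2 cycles: USDC -> crvUSD -> USDT -> USDC, and the reverse loop
```

Each hop sells the coin shared with the previous pool and buys the coin shared with the next one; `1 - index` plays the role of `idx ^ 1` in `construct_cycle`.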

In [8]:
n = 3
graph = PoolGraph(pools)
cycles = graph.find_cycles(n=n)
cycles

[INFO][02:42:41][root]-444465: Found 40 valid cycles of length 3.


[Cycle(Trades: [Swap(pool=<SimCurveStableSwapPool address=0x34d655069f4cac1547e4c8ca284ffff5ad4a8db0 chain=mainnet>, i=0, j=1, amt=None), Swap(pool=<SimCurveStableSwapPool address=0xca978a0528116dda3cba9acd3e68bc6191ca53d0 chain=mainnet>, i=1, j=0, amt=None), Swap(pool=External Market: USDP -> TUSD, i=0, j=1, amt=None)], Expected Profit: None),
 Cycle(Trades: [Swap(pool=<SimCurveStableSwapPool address=0x34d655069f4cac1547e4c8ca284ffff5ad4a8db0 chain=mainnet>, i=0, j=1, amt=None), Swap(pool=<SimCurveStableSwapPool address=0xca978a0528116dda3cba9acd3e68bc6191ca53d0 chain=mainnet>, i=1, j=0, amt=None), Swap(pool=External Market: TUSD -> USDP, i=1, j=0, amt=None)], Expected Profit: None),
 Cycle(Trades: [Swap(pool=<SimCurveStableSwapPool address=0x34d655069f4cac1547e4c8ca284ffff5ad4a8db0 chain=mainnet>, i=0, j=1, amt=None), Swap(pool=<SimCurveStableSwapPool address=0x4dece678ceceb27446b35c672dc7d61f30bad69e chain=mainnet>, i=1, j=0, amt=None), Swap(pool=External Market: USDC -> TUSD, i=0, 

In [14]:
# Sanity checks on the cycles found above:
# 1. Each cycle has length n.
# 2. Consecutive pools, including the closing pair (last -> first), share a coin.
# 3. Each shared coin is used once, so a cycle links n distinct coins.
# TODO move this to a unit test file

for cycle in cycles:
    # Use a new name so the global `pools` list isn't overwritten.
    cycle_pools = [t.pool for t in cycle.trades]
    assert cycle.n == n, "Wrong length"
    used = set()
    for k in range(n):
        used.update(shared_address(cycle_pools[k], cycle_pools[(k + 1) % n]))
        assert len(used) == k + 1, "Pools must be linked by a new coin at each step"
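The TODO above suggests moving these checks into a unit test file. A hedged sketch of what such tests could look like, written against plain coin pairs so they run without crvusdsim: `cycle_coins` and the `test_*` functions are hypothetical names, and pytest would collect the `test_*` functions if they live in a file named `test_*.py`.

```python
from typing import List, Set, Tuple

CoinPair = Tuple[str, str]


def cycle_coins(pairs: List[CoinPair]) -> Set[str]:
    """Return the coins linking consecutive pools, including last -> first.

    Fails if two consecutive pools don't share exactly one coin, or if a
    linking coin is reused elsewhere in the cycle.
    """
    n = len(pairs)
    used: Set[str] = set()
    for k in range(n):
        link = set(pairs[k]) & set(pairs[(k + 1) % n])
        assert len(link) == 1, f"pools {k} and {(k + 1) % n} must share one coin"
        assert not (link & used), f"linking coin {link} is reused"
        used |= link
    return used


def test_closed_cycle_uses_distinct_coins():
    pairs = [("USDC", "crvUSD"), ("crvUSD", "USDT"), ("USDT", "USDC")]
    assert cycle_coins(pairs) == {"USDC", "crvUSD", "USDT"}


def test_unclosed_cycle_is_rejected():
    pairs = [("USDC", "crvUSD"), ("crvUSD", "USDT"), ("USDT", "DAI")]
    try:
        cycle_coins(pairs)
    except AssertionError:
        return
    raise AssertionError("expected an unclosed cycle to be rejected")
```

With pytest installed, `pytest.raises(AssertionError)` is the more idiomatic way to write the second test.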