Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: started working on multihop swaps for v3 #221

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_uniswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def client(request, web3: Web3, ganache: GanacheInstance):
ganache.eth_privkey,
web3=web3,
version=request.param,
use_estimate_gas=False, # see note in _build_and_send_tx
use_estimate_gas=True, # see note in _build_and_send_tx
)


Expand Down
167 changes: 96 additions & 71 deletions uniswap/uniswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def get_price_input(
token1: AddressLike, # output token
qty: int,
fee: int = None,
route: Optional[List[AddressLike]] = None,
route: List[AddressLike] = None,
) -> int:
"""Given `qty` amount of the input `token0`, returns the maximum output amount of output `token1`."""
if fee is None:
Expand All @@ -203,7 +203,7 @@ def get_price_output(
token1: AddressLike,
qty: int,
fee: int = None,
route: Optional[List[AddressLike]] = None,
route: List[AddressLike] = None,
) -> int:
"""Returns the minimum amount of `token0` required to buy `qty` amount of `token1`."""
if fee is None:
Expand Down Expand Up @@ -264,7 +264,7 @@ def _get_token_token_input_price(
token1: AddressLike, # output token
qty: int,
fee: int,
route: Optional[List[AddressLike]] = None,
route: List[AddressLike] = None,
) -> int:
"""
Public price (i.e. amount of output token received) for token to token trades with an exact input.
Expand All @@ -286,13 +286,14 @@ def _get_token_token_input_price(
if self.version == 2:
price: int = self.router.functions.getAmountsOut(qty, route).call()[-1]
elif self.version == 3:
# FIXME: How to calculate this properly? See https://docs.uniswap.org/reference/libraries/SqrtPriceMath
sqrtPriceLimitX96 = 0

if route:
# NOTE: to support custom routes we need to support the Path data encoding: https://github.com/Uniswap/uniswap-v3-periphery/blob/main/contracts/libraries/Path.sol
# result: tuple = self.quoter.functions.quoteExactInput(route, qty).call()
raise Exception("custom route not yet supported for v3")

# FIXME: How to calculate this properly? See https://docs.uniswap.org/reference/libraries/SqrtPriceMath
sqrtPriceLimitX96 = 0
price = self.quoter.functions.quoteExactInputSingle(
token0, token1, fee, qty, sqrtPriceLimitX96
).call()
Expand Down Expand Up @@ -580,68 +581,93 @@ def _token_to_eth_swap_input(
function = token_funcs.tokenToEthTransferInput(*func_params)
return self._build_and_send_tx(function)
elif self.version == 2:
if recipient is None:
recipient = self.address
amount_out_min = int(
(1 - slippage) * self._get_token_eth_input_price(input_token, qty, fee)
)
if fee_on_transfer:
func = (
self.router.functions.swapExactTokensForETHSupportingFeeOnTransferTokens
)
else:
func = self.router.functions.swapExactTokensForETH
return self._build_and_send_tx(
func(
qty,
amount_out_min,
[input_token, self.get_weth_address()],
recipient,
self._deadline(),
),
return self._token_to_eth_swap_input_v2(
input_token, qty, recipient, fee, slippage, fee_on_transfer
)
elif self.version == 3:
if recipient is None:
recipient = self.address

if fee_on_transfer:
raise Exception("fee on transfer not supported by Uniswap v3")

output_token = self.get_weth_address()
min_tokens_bought = int(
(1 - slippage)
* self._get_token_eth_input_price(input_token, qty, fee=fee)
return self._token_to_eth_swap_input_v3(
input_token, qty, recipient, fee, slippage
)
sqrtPriceLimitX96 = 0
else:
raise ValueError

swap_data = self.router.encodeABI(
fn_name="exactInputSingle",
args=[
(
input_token,
output_token,
fee,
ETH_ADDRESS,
self._deadline(),
qty,
min_tokens_bought,
sqrtPriceLimitX96,
)
],
def _token_to_eth_swap_input_v2(
self,
input_token: AddressLike,
qty: int,
recipient: Optional[AddressLike],
fee: int,
slippage: float,
fee_on_transfer: bool,
) -> HexBytes:
if recipient is None:
recipient = self.address
amount_out_min = int(
(1 - slippage) * self._get_token_eth_input_price(input_token, qty, fee)
)
if fee_on_transfer:
func = (
self.router.functions.swapExactTokensForETHSupportingFeeOnTransferTokens
)
else:
func = self.router.functions.swapExactTokensForETH
return self._build_and_send_tx(
func(
qty,
amount_out_min,
[input_token, self.get_weth_address()],
recipient,
self._deadline(),
),
)

unwrap_data = self.router.encodeABI(
fn_name="unwrapWETH9", args=[min_tokens_bought, recipient]
)
def _token_to_eth_swap_input_v3(
self,
input_token: AddressLike,
qty: int,
recipient: Optional[AddressLike],
fee: int,
slippage: float,
) -> HexBytes:
"""NOTE: Should always be called via the dispatcher `_token_to_eth_swap_input`"""
if recipient is None:
recipient = self.address

# Multicall
return self._build_and_send_tx(
self.router.functions.multicall([swap_data, unwrap_data]),
self._get_tx_params(),
)
output_token = self.get_weth_address()
min_tokens_bought = int(
(1 - slippage) * self._get_token_eth_input_price(input_token, qty, fee=fee)
)
sqrtPriceLimitX96 = 0

swap_data = self.router.encodeABI(
fn_name="exactInputSingle",
args=[
(
input_token,
output_token,
fee,
ETH_ADDRESS,
self._deadline(),
qty,
min_tokens_bought,
sqrtPriceLimitX96,
)
],
)

else:
raise ValueError
# NOTE: This will probably lead to dust WETH accumulation
unwrap_data = self.router.encodeABI(
fn_name="unwrapWETH9", args=[min_tokens_bought, recipient]
)

# Multicall
return self._build_and_send_tx(
self.router.functions.multicall([swap_data, unwrap_data]),
self._get_tx_params(),
)

def _token_to_token_swap_input(
self,
Expand Down Expand Up @@ -1110,13 +1136,17 @@ def _build_and_send_tx(
# `use_estimate_gas` needs to be True for networks like Arbitrum (can't assume 250000 gas),
# but it breaks tests for unknown reasons because estimateGas takes forever on some tx's.
# Maybe an issue with ganache? (got GC warnings once...)

# In case gas estimation is disabled.
# Without this set before gas estimation, it can lead to ganache stack overflow.
# See: https://github.com/trufflesuite/ganache/issues/985#issuecomment-998937085
transaction["gas"] = Wei(250000)

if self.use_estimate_gas:
# The Uniswap V3 UI uses 20% margin for transactions
transaction["gas"] = Wei(
int(self.w3.eth.estimate_gas(transaction) * 1.2)
)
else:
transaction["gas"] = Wei(250000)

signed_txn = self.w3.eth.account.sign_transaction(
transaction, private_key=self.private_key
Expand Down Expand Up @@ -1224,11 +1254,11 @@ def get_token(self, address: AddressLike, abi_name: str = "erc20") -> ERC20Token
raise InvalidToken(address)
try:
name = _name.decode()
except:
except Exception: # FIXME: Be more precise about exception to catch
name = _name
try:
symbol = _symbol.decode()
except:
except Exception: # FIXME: Be more precise about exception to catch
symbol = _symbol
return ERC20Token(symbol, address, name, decimals)

Expand All @@ -1255,11 +1285,11 @@ def get_raw_price(
if token_out == ETH_ADDRESS:
token_out = self.get_weth_address()

params: Tuple[ChecksumAddress, ChecksumAddress] = (
self.w3.toChecksumAddress(token_in),
self.w3.toChecksumAddress(token_out),
)
if self.version == 2:
params: Iterable[Union[ChecksumAddress,Optional[int]]] = [
self.w3.toChecksumAddress(token_in),
self.w3.toChecksumAddress(token_out),
]
pair_token = self.factory_contract.functions.getPair(*params).call()
token_in_erc20 = _load_contract_erc20(
self.w3, self.w3.toChecksumAddress(token_in)
Expand All @@ -1285,12 +1315,7 @@ def get_raw_price(

raw_price = token_out_balance / token_in_balance
else:
params = [
self.w3.toChecksumAddress(token_in),
self.w3.toChecksumAddress(token_out),
fee,
]
pool_address = self.factory_contract.functions.getPool(*params).call()
pool_address = self.factory_contract.functions.getPool(*params, fee).call()
pool_contract = _load_contract(
self.w3, abi_name="uniswap-v3/pool", address=pool_address
)
Expand All @@ -1316,7 +1341,7 @@ def estimate_price_impact(
token_out: AddressLike,
amount_in: int,
fee: int = None,
route: Optional[List[AddressLike]] = None,
route: List[AddressLike] = None,
) -> float:
"""
Returns the estimated price impact as a positive float (0.01 = 1%).
Expand Down
80 changes: 77 additions & 3 deletions uniswap/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import json
import functools
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Any, Dict
from dataclasses import dataclass

from web3 import Web3
from web3.exceptions import NameNotFound
from eth_abi import encode_abi

from .types import AddressLike, Address, Contract

Expand Down Expand Up @@ -57,10 +59,82 @@ def _load_contract_erc20(w3: Web3, address: AddressLike) -> Contract:
return _load_contract(w3, "erc20", address)


def _encode_path(token_in: AddressLike, route: List[Tuple[int, AddressLike]]) -> bytes:
@dataclass
class Pool(dict):
token0: AddressLike
token1: AddressLike
fee: int


@dataclass
class Route:
pools: List[Pool]


def _token_seq_to_route(tokens: List[AddressLike], fee: int = 3000) -> Route:
return Route(
pools=[
Pool(token0, token1, fee) for token0, token1 in zip(tokens[:-1], tokens[1:])
]
)


def _encode_path(
token_in: AddressLike,
route: List[Tuple[int, AddressLike]],
# route: Route,
exactOutput: bool,
) -> bytes:
"""
Needed for multi-hop swaps in V3.

https://github.com/Uniswap/uniswap-v3-sdk/blob/1a74d5f0a31040fec4aeb1f83bba01d7c03f4870/src/utils/encodeRouteToPath.ts
"""
raise NotImplementedError
from functools import reduce

_route = _token_seq_to_route([token_in] + [token for fee, token in route])

def merge(acc: Dict[str, Any], pool: Pool) -> Dict[str, Any]:
"""Returns a dict with the keys: inputToken, path, types"""
index = 0 if not acc["types"] else None
inputToken = acc["inputToken"]
outputToken = pool.token1 if pool.token0 == inputToken else pool.token0
if index == 0:
return {
"inputToken": outputToken,
"types": ["address", "uint24", "address"],
"path": [inputToken, pool.fee, outputToken],
}
else:
return {
"inputToken": outputToken,
"types": [*acc["types"], "uint24", "address"],
"path": [*path, pool.fee, outputToken],
}

params = reduce(
merge,
_route.pools,
{"inputToken": _addr_to_str(token_in), "path": [], "types": []},
)
types = params["types"]
path = params["path"]

if exactOutput:
encoded: bytes = encode_abi(list(reversed(types)), list(reversed(path)))
else:
encoded = encode_abi(types, path)

return encoded


def test_encode_path() -> None:
"""Take tests from: https://github.com/Uniswap/uniswap-v3-sdk/blob/1a74d5f0a31040fec4aeb1f83bba01d7c03f4870/src/utils/encodeRouteToPath.test.ts"""
from uniswap.tokens import tokens

# TODO: Actually assert testcases
path = _encode_path(tokens["WETH"], [(3000, tokens["DAI"])], exactOutput=True)
print(path)

path = _encode_path(tokens["WETH"], [(3000, tokens["DAI"])], exactOutput=False)
print(path)