Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions apps/hip-3-pusher/config/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@ market_symbol = "BTC"
use_testnet = false
oracle_pusher_key_path = "/path/to/oracle_pusher_key.txt"
publish_interval = 3.0
publish_timeout = 5.0
enable_publish = false

[kms]
enable_kms = false
key_path = "/path/to/aws_kms_key_id.txt"
access_key_id_path = "/path/to/aws_access_key_id.txt"
secret_access_key_path = "/path/to/aws_secret_access_key.txt"
aws_region_name = "ap-northeast-1"
aws_kms_key_id_path = "/path/to/aws_kms_key_id.txt"

[lazer]
lazer_urls = ["wss://pyth-lazer-0.dourolabs.app/v1/stream", "wss://pyth-lazer-1.dourolabs.app/v1/stream"]
Expand Down
2 changes: 0 additions & 2 deletions apps/hip-3-pusher/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ dependencies = [
"opentelemetry-exporter-prometheus~=0.58b0",
"opentelemetry-sdk~=1.37.0",
"prometheus-client~=0.23.1",
"setuptools~=80.9",
"tenacity~=9.1.2",
"websockets~=15.0.1",
"wheel~=0.45.1",
]

[build-system]
Expand Down
19 changes: 13 additions & 6 deletions apps/hip-3-pusher/src/pusher/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from pydantic import BaseModel
from hyperliquid.utils.constants import MAINNET_API_URL, TESTNET_API_URL
from pydantic import BaseModel, FilePath, model_validator
from typing import Optional

STALE_TIMEOUT_SECONDS = 5


class KMSConfig(BaseModel):
enable_kms: bool
aws_region_name: str
key_path: str
access_key_id_path: str
secret_access_key_path: str
aws_kms_key_id_path: FilePath


class LazerConfig(BaseModel):
Expand All @@ -30,13 +29,21 @@ class HermesConfig(BaseModel):

class HyperliquidConfig(BaseModel):
hyperliquid_ws_urls: list[str]
push_urls: Optional[list[str]] = None
market_name: str
market_symbol: str
use_testnet: bool
oracle_pusher_key_path: str
oracle_pusher_key_path: FilePath
publish_interval: float
publish_timeout: float
enable_publish: bool

@model_validator(mode="after")
def set_default_urls(self):
if self.push_urls is None:
self.push_urls = [TESTNET_API_URL] if self.use_testnet else [MAINNET_API_URL]
return self


class Config(BaseModel):
stale_price_threshold_seconds: int
Expand Down
6 changes: 5 additions & 1 deletion apps/hip-3-pusher/src/pusher/exception.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class StaleConnection(Exception):
class StaleConnectionError(Exception):
pass


class PushError(Exception):
pass
6 changes: 3 additions & 3 deletions apps/hip-3-pusher/src/pusher/hermes_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tenacity import retry, retry_if_exception_type, wait_exponential

from pusher.config import Config, STALE_TIMEOUT_SECONDS
from pusher.exception import StaleConnection
from pusher.exception import StaleConnectionError
from pusher.price_state import PriceState, PriceUpdate


Expand Down Expand Up @@ -34,7 +34,7 @@ async def subscribe_all(self):
await asyncio.gather(*(self.subscribe_single(url) for url in self.hermes_urls))

@retry(
retry=retry_if_exception_type((StaleConnection, websockets.ConnectionClosed)),
retry=retry_if_exception_type((StaleConnectionError, websockets.ConnectionClosed)),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
Expand All @@ -55,7 +55,7 @@ async def subscribe_single_inner(self, url):
data = json.loads(message)
self.parse_hermes_message(data)
except asyncio.TimeoutError:
raise StaleConnection(f"No messages in {STALE_TIMEOUT_SECONDS} seconds, reconnecting")
raise StaleConnectionError(f"No messages in {STALE_TIMEOUT_SECONDS} seconds, reconnecting")
except websockets.ConnectionClosed:
raise
except json.JSONDecodeError as e:
Expand Down
6 changes: 3 additions & 3 deletions apps/hip-3-pusher/src/pusher/hyperliquid_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time

from pusher.config import Config, STALE_TIMEOUT_SECONDS
from pusher.exception import StaleConnection
from pusher.exception import StaleConnectionError
from pusher.price_state import PriceState, PriceUpdate

# This will be in config, but note here.
Expand Down Expand Up @@ -35,7 +35,7 @@ async def subscribe_all(self):
await asyncio.gather(*(self.subscribe_single(hyperliquid_ws_url) for hyperliquid_ws_url in self.hyperliquid_ws_urls))

@retry(
retry=retry_if_exception_type((StaleConnection, websockets.ConnectionClosed)),
retry=retry_if_exception_type((StaleConnectionError, websockets.ConnectionClosed)),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
Expand Down Expand Up @@ -65,7 +65,7 @@ async def subscribe_single_inner(self, url):
else:
logger.error("Received unknown channel: {}", channel)
except asyncio.TimeoutError:
raise StaleConnection(f"No messages in {STALE_TIMEOUT_SECONDS} seconds, reconnecting...")
raise StaleConnectionError(f"No messages in {STALE_TIMEOUT_SECONDS} seconds, reconnecting...")
except websockets.ConnectionClosed:
raise
except json.JSONDecodeError as e:
Expand Down
105 changes: 58 additions & 47 deletions apps/hip-3-pusher/src/pusher/kms_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,51 +6,61 @@
from eth_keys.datatypes import Signature
from eth_utils import keccak, to_hex
from hyperliquid.exchange import Exchange
from hyperliquid.utils.constants import TESTNET_API_URL, MAINNET_API_URL
from hyperliquid.utils.signing import get_timestamp_ms, action_hash, construct_phantom_agent, l1_payload
from loguru import logger
from pathlib import Path

from pusher.config import Config
from pusher.exception import PushError

SECP256K1_N_HALF = SECP256K1_N // 2


def _init_client():
# AWS_DEFAULT_REGION, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY should be set as environment variables
return boto3.client(
"kms",
# can specify an endpoint for e.g. LocalStack
# endpoint_url="http://localhost:4566"
)


class KMSSigner:
def __init__(self, config: Config):
use_testnet = config.hyperliquid.use_testnet
url = TESTNET_API_URL if use_testnet else MAINNET_API_URL
self.oracle_publisher_exchange: Exchange = Exchange(wallet=None, base_url=url)
self.client = self._init_client(config)
def __init__(self, config: Config, publisher_exchanges: list[Exchange]):
self.use_testnet = config.hyperliquid.use_testnet
self.publisher_exchanges = publisher_exchanges

# AWS client and public key load
self.client = _init_client()
try:
self._load_public_key(config.kms.aws_kms_key_id_path)
except Exception as e:
logger.exception("Failed to load public key from KMS; it might be incorrectly configured; error: {}", repr(e))
exit()
Comment on lines +37 to +39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Invoking exit() isn't ideal (often times this doesn't work the way you expect, esp in threaded/async workloads.) Lets just rethrow a chained exception. It should crash the program as expected. Something like

except Exception as e:
    logger.exception(...)
    throw Exception("Failed to load public key from KMS") from e


def _load_public_key(self, key_path: str):
# Fetch public key once so we can derive address and check recovery id
key_path = config.kms.key_path
self.key_id = open(key_path, "r").read().strip()
self.pubkey_der = self.client.get_public_key(KeyId=self.key_id)["PublicKey"]
self.aws_kms_key_id = Path(key_path).read_text().strip()
pubkey_der = self.client.get_public_key(KeyId=self.aws_kms_key_id)["PublicKey"]
self.pubkey = serialization.load_der_public_key(pubkey_der)
self._construct_pubkey_address_and_bytes()

def _construct_pubkey_address_and_bytes(self):
# Construct eth address to log
pub = serialization.load_der_public_key(self.pubkey_der)
numbers = pub.public_numbers()
numbers = self.pubkey.public_numbers()
x = numbers.x.to_bytes(32, "big")
y = numbers.y.to_bytes(32, "big")
uncompressed = b"\x04" + x + y
self.public_key_bytes = uncompressed
self.address = "0x" + keccak(uncompressed[1:])[-20:].hex()
logger.info("KMSSigner address: {}", self.address)

def _init_client(self, config):
aws_region_name = config.kms.aws_region_name
access_key_id_path = config.kms.access_key_id_path
access_key_id = open(access_key_id_path, "r").read().strip()
secret_access_key_path = config.kms.secret_access_key_path
secret_access_key = open(secret_access_key_path, "r").read().strip()

return boto3.client(
"kms",
region_name=aws_region_name,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
# can specify an endpoint for e.g. LocalStack
# endpoint_url="http://localhost:4566"
logger.info("public key loaded from KMS: {}", self.address)

# Parse KMS public key into uncompressed secp256k1 bytes
pubkey_bytes = self.pubkey.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint,
)
# Strip leading 0x04 (uncompressed point indicator)
self.raw_pubkey_bytes = pubkey_bytes[1:]

def set_oracle(self, dex, oracle_pxs, all_mark_pxs, external_perp_pxs):
timestamp = get_timestamp_ms()
Expand All @@ -67,15 +77,24 @@ def set_oracle(self, dex, oracle_pxs, all_mark_pxs, external_perp_pxs):
},
}
signature = self.sign_l1_action(
action,
timestamp,
self.oracle_publisher_exchange.base_url == MAINNET_API_URL,
)
return self.oracle_publisher_exchange._post_action(
action,
signature,
timestamp,
action=action,
nonce=timestamp,
is_mainnet=not self.use_testnet,
)
return self._send_update(action, signature, timestamp)

def _send_update(self, action, signature, timestamp):
for exchange in self.publisher_exchanges:
try:
return exchange._post_action(
action=action,
signature=signature,
nonce=timestamp,
)
except Exception as e:
logger.exception("perp_deploy_set_oracle exception for endpoint: {} error: {}", exchange.base_url, repr(e))

raise PushError("all push endpoints failed")

def sign_l1_action(self, action, nonce, is_mainnet):
hash = action_hash(action, vault_address=None, nonce=nonce, expires_after=None)
Expand All @@ -88,7 +107,7 @@ def sign_l1_action(self, action, nonce, is_mainnet):
def sign_message(self, message_hash: bytes) -> dict:
# Send message hash to KMS for signing
resp = self.client.sign(
KeyId=self.key_id,
KeyId=self.aws_kms_key_id,
Message=message_hash,
MessageType="DIGEST",
SigningAlgorithm="ECDSA_SHA_256", # required for secp256k1
Expand All @@ -99,20 +118,12 @@ def sign_message(self, message_hash: bytes) -> dict:
# Ethereum requires low-s form
if s > SECP256K1_N_HALF:
s = SECP256K1_N - s
# Parse KMS public key into uncompressed secp256k1 bytes
# TODO: Pull this into init
pubkey = serialization.load_der_public_key(self.pubkey_der)
pubkey_bytes = pubkey.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint,
)
# Strip leading 0x04 (uncompressed point indicator)
raw_pubkey_bytes = pubkey_bytes[1:]

# Try both recovery ids
for v in (0, 1):
sig_obj = Signature(vrs=(v, r, s))
recovered_pub = sig_obj.recover_public_key_from_msg_hash(message_hash)
if recovered_pub.to_bytes() == raw_pubkey_bytes:
if recovered_pub.to_bytes() == self.raw_pubkey_bytes:
return {
"r": to_hex(r),
"s": to_hex(s),
Expand Down
6 changes: 3 additions & 3 deletions apps/hip-3-pusher/src/pusher/lazer_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tenacity import retry, retry_if_exception_type, wait_exponential

from pusher.config import Config, STALE_TIMEOUT_SECONDS
from pusher.exception import StaleConnection
from pusher.exception import StaleConnectionError
from pusher.price_state import PriceState, PriceUpdate


Expand Down Expand Up @@ -38,7 +38,7 @@ async def subscribe_all(self):
await asyncio.gather(*(self.subscribe_single(router_url) for router_url in self.lazer_urls))

@retry(
retry=retry_if_exception_type((StaleConnection, websockets.ConnectionClosed)),
retry=retry_if_exception_type((StaleConnectionError, websockets.ConnectionClosed)),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
Expand All @@ -63,7 +63,7 @@ async def subscribe_single_inner(self, router_url):
data = json.loads(message)
self.parse_lazer_message(data)
except asyncio.TimeoutError:
raise StaleConnection(f"No messages in {STALE_TIMEOUT_SECONDS} seconds, reconnecting")
raise StaleConnectionError(f"No messages in {STALE_TIMEOUT_SECONDS} seconds, reconnecting")
except websockets.ConnectionClosed:
raise
except json.JSONDecodeError as e:
Expand Down
2 changes: 1 addition & 1 deletion apps/hip-3-pusher/src/pusher/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ async def main():
try:
asyncio.run(main())
except Exception as e:
logger.exception("Uncaught exception, exiting: {}", e)
logger.exception("Uncaught exception, exiting; error: {}", repr(e))
9 changes: 5 additions & 4 deletions apps/hip-3-pusher/src/pusher/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ def __init__(self, config: Config):
reader = PrometheusMetricReader()
# Meter is responsible for creating and recording metrics
set_meter_provider(MeterProvider(metric_readers=[reader]))
# TODO: sync version number and add?
self.meter = get_meter_provider().get_meter(METER_NAME)

self._init_metrics()

def _init_metrics(self):
Expand All @@ -35,5 +33,8 @@ def _init_metrics(self):
name="hip_3_pusher_failed_push_count",
description="Number of failed push attempts",
)

# TODO: labels/attributes
self.push_interval_histogram = self.meter.create_histogram(
name="hip_3_pusher_push_interval",
description="Interval between push requests (seconds)",
unit="s",
)
Loading