In [1]:
import argparse
import logging
import os
import pprint
from logging import getLogger

In [2]:
import numpy as np

In [3]:
from dotenv import find_dotenv, load_dotenv

In [4]:
from giza.agents import AgentResult, GizaAgent


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from addresses import ADDRESSES

In [6]:
# from lp_tools import get_tick_range
# from uni_helpers import (approve_token, check_allowance, close_position,
#                          get_all_user_positions, get_mint_params)

In [7]:
# Load the .env file located by find_dotenv() so the os.environ lookups in
# later cells can see its values.
load_dotenv(dotenv_path=find_dotenv())


True

In [8]:
# Both values are optional at this point: a missing variable yields None.
dev_passphrase = os.getenv("DEV_PASSPHRASE")
sepolia_rpc_url = os.getenv("SEPOLIA_RPC_URL")

In [9]:
# Configure the root logger to emit INFO-level records and above.
logging.basicConfig(level="INFO")

In [10]:
def create_agent(
    model_id: int, version_id: int, chain: str, contracts: dict, account: str
):
    """
    Build the Giza agent used for the volatility prediction model.

    Args:
        model_id (int): Forwarded to ``GizaAgent`` as ``id``.
        version_id (int): Forwarded to ``GizaAgent`` as ``version_id``.
        chain (str): Forwarded to ``GizaAgent`` as ``chain``.
        contracts (dict): Forwarded to ``GizaAgent`` as ``contracts``.
        account (str): Forwarded to ``GizaAgent`` as ``account``.

    Returns:
        GizaAgent: The newly constructed agent.
    """
    # Collect the constructor arguments first so the mapping from this
    # function's parameters to GizaAgent's keywords is visible in one place.
    agent_kwargs = {
        "contracts": contracts,
        "id": model_id,
        "version_id": version_id,
        "chain": chain,
        "account": account,
    }
    return GizaAgent(**agent_kwargs)

In [11]:
def predict(agent: GizaAgent, X: np.ndarray):
    """
    Request a verifiable next-day volatility prediction from the agent.

    Args:
        agent (GizaAgent): Agent whose ``predict`` method is called.
        X (np.ndarray): Input to the model, sent under the ``"val"`` key.

    Returns:
        The object returned by ``agent.predict``, passed through unchanged.
    """
    model_input = {"val": X}
    return agent.predict(input_feed=model_input, verifiable=True, job_size="XL")

In [12]:
def get_pred_val(prediction: AgentResult):
    """
    Extract the predicted value from an agent result.

    Args:
        prediction (AgentResult): Result returned by the agent's prediction.

    Returns:
        The element at position ``[0][0]`` of ``prediction.value``.
    """
    # Reading ``.value`` blocks execution until the proof for the prediction
    # has been generated and verified.
    values = prediction.value
    first_row = values[0]
    return first_row[0]

In [None]:
def rebalance_lp(
    tokenWETH_amount: int,
    tokenUSDC_amount: int,
    pred_model_id: int,
    pred_version_id: int,
    account="giza1",
    chain=f"ethereum:sepolia:{sepolia_rpc_url}",
    nft_id=None,
):
    """
    Swap WETH for USDC through the swap router and print the wallet balances.

    Args:
        tokenWETH_amount: WETH amount to swap, in whole-token units; it is
            multiplied by ``10**decimals`` of the WETH contract before use.
        tokenUSDC_amount: Accepted for interface compatibility; not used yet.
        pred_model_id: Accepted for interface compatibility; not used yet.
        pred_version_id: Accepted for interface compatibility; not used yet.
        account: Value handed to ``accounts.use_sender`` for the swap.
        chain: Network choice string handed to
            ``networks.parse_network_choice``.
        nft_id: Accepted for interface compatibility; not used yet.

    Returns:
        None. Progress and balances are printed.

    NOTE(review): ``networks``, ``Contract``, ``accounts``, ``wallet`` and
    ``swap_router`` are neither imported nor defined in the visible notebook
    cells. They must exist in the notebook namespace before this function is
    called (the first three presumably come from ``ape`` - confirm).
    """
    pool_fee = 3000

    # Use ``with`` instead of a bare ``__enter__()`` so the provider
    # connection is released when the function finishes or raises.
    with networks.parse_network_choice(chain) as provider:
        # ``chain`` is a plain string here, so the chain id has to be read
        # from the connected provider rather than from ``chain`` itself.
        chain_id = provider.chain_id

        weth = Contract(ADDRESSES["WETH"][chain_id])
        wusdc = Contract(ADDRESSES["USDC"][chain_id])
        weth_decimals = weth.decimals()
        wusdc_decimals = wusdc.decimals()
        weth_swap_amount = int(tokenWETH_amount * 10**weth_decimals)

        # Snapshot the balance so the amount actually received can be
        # reported after the swap.
        usdc_before = wusdc.balanceOf(wallet.address)

        with accounts.use_sender(account):
            # NOTE(review): amountOutMinimum is 0, i.e. no minimum output is
            # requested by these parameters - confirm this is acceptable.
            swap_params = {
                "tokenIn": weth.address,
                "tokenOut": wusdc.address,
                "fee": pool_fee,
                "recipient": wallet.address,
                "amountIn": weth_swap_amount,
                "amountOutMinimum": 0,
                "sqrtPriceLimitX96": 0,
            }
            # The router call takes the struct fields positionally, in the
            # order they are declared above.
            swap_params = tuple(swap_params.values())
            print("Swapping WETH for USDC")
            swap_router.exactInputSingle(swap_params)

        usdc_after = wusdc.balanceOf(wallet.address)
        usdc_received = usdc_after - usdc_before
        print(f"Successfully swapped for {usdc_received/10**wusdc_decimals} USDC")

        print(f"Your WETH balance: {weth.balanceOf(wallet.address)/10**weth_decimals}")
        print(f"Your WUSDC balance: {usdc_after/10**wusdc_decimals}")
    

In [None]:
if __name__ == "__main__":
    # Every run parameter is read from the environment, which is populated
    # from .env by the load_dotenv call earlier in the notebook.
    # NOTE(review): TOKEN_A_AMOUNT and TOKEN_B_AMOUNT replace the former
    # command-line arguments and must be added to .env.
    required_vars = ("MODEL_ID", "VERSION_ID", "TOKEN_A_AMOUNT", "TOKEN_B_AMOUNT")
    missing_vars = [name for name in required_vars if not os.environ.get(name)]
    if missing_vars:
        raise RuntimeError(
            "Missing required environment variables: " + ", ".join(missing_vars)
        )

    # Environment values are strings; convert them to the numeric types the
    # callee expects.
    MODEL_ID = int(os.environ["MODEL_ID"])
    VERSION_ID = int(os.environ["VERSION_ID"])
    # Amounts are in whole-token units (rebalance_lp scales them by the
    # token's decimals), so fractional values are allowed.
    tokenA_amount = float(os.environ["TOKEN_A_AMOUNT"])
    tokenB_amount = float(os.environ["TOKEN_B_AMOUNT"])

    rebalance_lp(tokenA_amount, tokenB_amount, MODEL_ID, VERSION_ID)