In [None]:
import math

Q96 = 2**96


# Tick to sqrt price x 96
def tick_to_s96(tick):
    """Convert a tick index to an integer sqrt price scaled by ``Q96``.

    Evaluates ``1.0001 ** (tick / 2)`` in floating point, multiplies it by
    the module-level ``Q96`` and truncates the product with ``int``.
    """
    sqrt_price = 1.0001 ** (tick / 2)
    return int(sqrt_price * Q96)


def calc_x(L, s):
    """Return the x amount ``L / s`` for liquidity ``L`` at price ``s``.

    Asserts ``s >= 0``. Note that ``s == 0`` passes the assertion and then
    fails in the division.
    """
    assert 0 <= s
    amount_x = L / s
    return amount_x


def calc_y(L, s):
    """Return the y amount ``L * s`` for liquidity ``L`` at price ``s``.

    Asserts ``s >= 0``.
    """
    assert 0 <= s
    amount_y = L * s
    return amount_y


def calc_dx(L, s_lo, s_hi):
    """Return ``dx = L * (1 / s_lo - 1 / s_hi)`` between two prices.

    Both prices are asserted to be non-negative. When ``s_lo >= s_hi`` the
    range is empty and the integer ``0`` is returned without dividing.
    """
    assert 0 <= s_lo
    assert 0 <= s_hi
    # TODO: separate calc when s_hi = inf
    return 0 if s_lo >= s_hi else L * (1 / s_lo - 1 / s_hi)


def calc_dy(L, s_lo, s_hi):
    """Return ``dy = L * (s_hi - s_lo)`` between two prices.

    Both prices are asserted to be non-negative. When ``s_lo >= s_hi`` the
    range is empty and the integer ``0`` is returned.
    """
    assert 0 <= s_lo
    assert 0 <= s_hi
    # TODO: separate calc when s_lo = 0?
    if s_lo >= s_hi:
        return 0
    width = s_hi - s_lo
    return L * width


# dx = L(1/s_lo - 1/s_hi)
# dy = L(s_hi - s_lo)
def calc_dx_to_s_lo(L, s_hi, dx):
    """Solve ``dx = L * (1 / s_lo - 1 / s_hi)`` for ``s_lo``.

    Rearranged: ``1 / s_lo = dx / L + 1 / s_hi``. Asserts that ``s_hi`` and
    ``dx`` are non-negative.
    """
    assert 0 <= s_hi
    assert 0 <= dx
    inv_s_lo = dx / L + 1 / s_hi
    return 1 / inv_s_lo


def calc_dx_to_s_hi(L, s_lo, dx):
    """Solve ``dx = L * (1 / s_lo - 1 / s_hi)`` for ``s_hi``.

    Rearranged: ``1 / s_hi = -dx / L + 1 / s_lo``. Asserts that ``s_lo`` and
    ``dx`` are non-negative.
    """
    assert 0 <= s_lo
    assert 0 <= dx
    inv_s_hi = -dx / L + 1 / s_lo
    return 1 / inv_s_hi


def calc_dy_to_s_hi(L, s_lo, dy):
    """Solve ``dy = L * (s_hi - s_lo)`` for ``s_hi``.

    Asserts that ``s_lo`` and ``dy`` are non-negative.
    """
    assert 0 <= s_lo
    assert 0 <= dy
    rise = dy / L
    return rise + s_lo


def calc_dy_to_s_lo(L, s_hi, dy):
    """Solve ``dy = L * (s_hi - s_lo)`` for ``s_lo``.

    Asserts that ``s_hi`` and ``dy`` are non-negative.
    """
    assert 0 <= s_hi
    assert 0 <= dy
    drop = -dy / L
    return drop + s_hi


def next(pool, i, up):
    """Return the range after index ``i`` in ``pool``, or a sentinel range.

    When ``i + 1`` is past the end of ``pool`` a zero-liquidity sentinel is
    returned instead: ``(inf, inf, 0)`` if ``up`` is truthy, ``(0, 0, 0)``
    otherwise.

    NOTE: this name shadows the builtin ``next`` inside this module.
    """
    # TODO: return -inf or inf
    following = i + 1
    if following < len(pool):
        return pool[following]
    return (math.inf, math.inf, 0) if up else (0, 0, 0)


def calc_amt_out(i, o, di, f):
    """Return the output amount ``o * di' / (i + di')`` with ``di' = di * (1 - f)``.

    ``i`` and ``o`` are the input-side and output-side reserves, ``di`` is
    the amount in, and ``f`` is the fee fraction taken from ``di`` first.
    """
    net_in = di * (1 - f)
    return o * net_in / (i + net_in)


# Optimal dy amount in (including fees) into pool A (dya -> dx -> dyb)
def calc_opt_dy_in(xa, ya, xb, yb, fa, fb):
    """Return the y amount into pool A (fees included) from the quadratic below.

    With ``k0 = xa*ya*xb*yb*(1-fa)*(1-fb)`` and
    ``k1 = (xb + xa*(1-fb))*(1-fa)``, this returns the ``+sqrt`` root of
    ``k1^2 * d^2 + 2*k1*xb*ya * d + (xb*ya)^2 - k0 = 0``,
    i.e. the ``d`` for which ``(k1*d + xb*ya)^2 = k0``.
    """
    k0 = xa * ya * xb * yb * (1 - fa) * (1 - fb)
    k1 = (xb + xa * (1 - fb)) * (1 - fa)
    quad = k1 * k1
    lin = 2 * k1 * xb * ya
    const = (xb * ya) ** 2 - k0
    discriminant = lin * lin - 4 * quad * const
    root = math.sqrt(discriminant)
    return (-lin + root) / (2 * quad)


def calc_opt_dya(la, sa, lb, sb, fa, fb):
    """Compute the optimal single-range trade dya -> dx -> dyb.

    Reserves for each pool are derived from its liquidity and price via
    ``calc_x`` / ``calc_y`` and the input size comes from
    ``calc_opt_dy_in``.

    Returns:
        ``(dya, dyb, s)``: y paid into pool A (fees included), y received
        from pool B, and pool A's price after the trade.

    Raises:
        AssertionError: if ``0 <= sa <= sb`` does not hold, the optimal
            input is negative, the output is smaller than the input, or
            the resulting price leaves ``[sa, sb]``.
    """
    assert 0 <= sa <= sb
    # Reserves implied by (liquidity, price) for each pool.
    xa, ya = calc_x(la, sa), calc_y(la, sa)
    xb, yb = calc_x(lb, sb), calc_y(lb, sb)
    opt_in = calc_opt_dy_in(xa, ya, xb, yb, fa, fb)
    assert opt_in >= 0
    # Route the trade: y into A -> x out of A -> x into B -> y out of B.
    x_mid = calc_amt_out(ya, xa, opt_in, fa)
    y_out = calc_amt_out(xb, yb, x_mid, fb)
    assert y_out >= opt_in
    # Only the post-fee part of the input moves pool A's price.
    sa_after = calc_dy_to_s_hi(la, sa, opt_in * (1 - fa))
    assert sa <= sa_after <= sb
    # TODO: check sb after swap
    # TODO: assert delta, e.g. y_out == calc_dy(lb, sa_after, sb)
    return (opt_in, y_out, sa_after)


def swap_to_sa_hi(xa, xb, la, sa, sa_lo, sa_hi, lb, sb, sb_lo, sb_hi, fa, fb):
    """Swap until pool A reaches ``sa_hi``, sending all of ``xa`` into pool B.

    ``xb``, ``sa`` and ``sb`` are accepted for signature parity with
    ``swap_to_sb_lo`` but are not read here.

    Returns:
        ``(dya, dyb, sa, sb)``: y paid into A (grossed up by ``fa``), y
        received from B, and the new prices of A and B.

    Raises:
        AssertionError: if pool B's new price falls below ``sb_lo``.
    """
    # y needed to move A across its whole range, grossed up for A's fee.
    gross_y_in = calc_dy(la, sa_lo, sa_hi) / (1 - fa)
    # B only receives xa net of B's fee; that sets B's new (lower) price.
    sb_after = calc_dx_to_s_lo(lb, sb_hi, xa * (1 - fb))
    assert sb_lo <= sb_after
    y_out = calc_dy(lb, sb_after, sb_hi)
    return (gross_y_in, y_out, sa_hi, sb_after)


def swap_to_sb_lo(xa, xb, la, sa, sa_lo, sa_hi, lb, sb, sb_lo, sb_hi, fa, fb):
    """Swap until pool B reaches ``sb_lo`` or pool A runs out at ``sa_hi``.

    The x amount moved is ``min(xa, xb / (1 - fb))``: either everything
    pool A can give in its range, or the gross amount that leaves exactly
    ``xb`` after pool B's fee. The incoming ``sa`` and ``sb`` values are not
    read; both are recomputed.

    Returns:
        ``(dya, dyb, sa, sb)``: y paid into A (grossed up by ``fa``), y
        received from B, and the new prices of A and B.

    Raises:
        AssertionError: if a computed price leaves its range.
    """
    dx = min(xa, xb / (1 - fb))
    # True when pool A's range is the binding constraint.
    a_exhausted = dx == xa
    if a_exhausted:
        sa = sa_hi
    else:
        # Fixed: the original referenced an undefined name `s_lo` here.
        sa = calc_dx_to_s_hi(la, sa_lo, dx)
        assert sa <= sa_hi
    dya = calc_dy(la, sa_lo, sa) / (1 - fa)
    if a_exhausted:
        # B absorbs xa net of its fee and stops somewhere above sb_lo.
        sb = calc_dx_to_s_lo(lb, sb_hi, xa * (1 - fb))
        assert sb_lo <= sb
    else:
        sb = sb_lo
    dyb = calc_dy(lb, sb, sb_hi)
    return (dya, dyb, sa, sb)


# pa < pb
# dya -> dx -> dyb
# TODO: dxb -> dy -> dxa
# pool = [(lo, hi, liq)]
def calc_dya(pool_a, pool_b, fa, fb):
    """Accumulate the swap dya -> dx -> dyb across the ranges of two pools.

    Each pool is a list of ``(s_lo, s_hi, liquidity)`` tuples. Pool A starts
    at the low end of ``pool_a[0]`` and moves up; pool B starts at the high
    end of ``pool_b[0]`` and moves down. Successive list entries are used
    as each pool leaves its current range.

    Args:
        pool_a: ranges of pool A (the cheaper pool, pa < pb).
        pool_b: ranges of pool B.
        fa: fee fraction of pool A.
        fb: fee fraction of pool B.

    Returns:
        ``(dya, dyb, sa, sb)``: total y paid into A (fees included), total y
        received from B, and the final prices of A and B.

    Raises:
        IndexError: if either pool is empty.
        AssertionError: if an internal invariant fails.
    """
    (sa_lo, sa_hi, la) = pool_a[0]
    (sb_lo, sb_hi, lb) = pool_b[0]
    a = 0
    b = 0
    dya = 0
    dyb = 0
    sa = sa_lo
    sb = sb_hi
    while sa_lo < sb_hi:
        if sa_hi <= sb_lo:
            # Ranges do not overlap: one pool crosses its whole range.
            xa = calc_dx(la, sa_lo, sa_hi)
            xb = calc_dx(lb, sb_lo, sb_hi)
            if xa <= xb:
                # swap to sa_hi
                (da, db, sa, sb) = swap_to_sa_hi(
                    xa, xb, la, sa, sa_lo, sa_hi, lb, sb, sb_lo, sb_hi, fa, fb
                )
            else:
                # swap to sb_lo
                (da, db, sa, sb) = swap_to_sb_lo(
                    xa, xb, la, sa, sa_lo, sa_hi, lb, sb, sb_lo, sb_hi, fa, fb
                )
        else:
            # Ranges overlap: the optimum may lie inside the current ranges.
            (dya_opt, dyb_opt, s_opt) = calc_opt_dya(la, sa, lb, sb, fa, fb)
            if sa_hi < s_opt:
                # swap to sa_hi
                da = calc_dy(la, sa_lo, sa_hi) / (1 - fa)
                dx = calc_dx(la, sa_lo, sa_hi)
                sa = sa_hi
                sb = calc_dx_to_s_lo(lb, sb_hi, dx * (1 - fb))
                assert sb_lo <= sb
                db = calc_dy(lb, sb, sb_hi)
            elif s_opt < sb_lo:
                # swap to sb_lo
                xb = calc_dx(lb, sb_lo, sb_hi)
                xa = calc_dx(la, sa_lo, sa_hi)
                dx = min(xa, xb / (1 - fb))
                if dx == xa:
                    sa = sa_hi
                else:
                    sa = calc_dx_to_s_hi(la, sa_lo, dx)
                    assert sa <= sa_hi
                da = calc_dy(la, sa_lo, sa) / (1 - fa)
                if dx == xa:
                    sb = calc_dx_to_s_lo(lb, sb_hi, dx * (1 - fb))
                    assert sb_lo <= sb
                else:
                    sb = sb_lo
                db = calc_dy(lb, sb, sb_hi)
            else:
                # swap to s_opt
                da = dya_opt
                db = dyb_opt
                # TODO: check actual sa and sb
                sa = s_opt
                sb = s_opt
        dya += da
        dyb += db
        # Prices have met (or crossed): nothing left to arbitrage.
        if sb <= sa:
            break
        # Advance the price ranges after every step. Without this, a partial
        # swap in the overlapping case would repeat the same step forever.
        if sa == sa_hi:
            (sa_lo, sa_hi, la) = next(pool_a, a, True)
            a += 1
        else:
            sa_lo = sa
        if sb == sb_lo:
            (sb_lo, sb_hi, lb) = next(pool_b, b, False)
            b += 1
        else:
            sb_hi = sb
    assert dya <= dyb
    assert sa <= sb
    return (dya, dyb, sa, sb)


# Placeholder inputs: two empty pools (lists of (lo, hi, liq) ranges) and
# zero fees. NOTE: calc_dya reads pool_a[0] and pool_b[0], so this call raises
# IndexError until the pools are populated.
pool_a, pool_b = [], []
fa = fb = 0
calc_dya(pool_a, pool_b, fa, fb)