In [19]:
from collections import defaultdict

class Bucket:
    def __init__(self):
        self.amt0 = 0
        self.amt1 = 0
        self.amt_out_rem = 0
        self.total = 0
        self.rem = 0
        self.amt_ins = defaultdict(int)

    def inc(self, user, amt):
        self.total += amt
        self.rem += amt
        self.amt_ins[user] += amt

    def dec(self, user, amt):
        self.total -= amt
        self.rem -= amt
        self.amt_ins[user] -= amt

    def __repr__(self):
        return f'amt0: {self.amt0} | amt1: {self.amt1} | total: {self.total} | amt_ins: {self.amt_ins}'

P1 = 1.01
    
# TODO: assert self.tick % tick_spacing = 0
class OrderBook:
    def __init__(self):
        tick = 0
        tick_spacing = 10
        assert tick % tick_spacing == 0
        self.tick = tick
        self.tick_spacing = tick_spacing
        self.dp = P1**tick
        self.p = P1**tick
        # bucket id => slot
        self.slots = defaultdict(int)
        # bucket id => slot => Bucket
        self.buckets = {}

    def get_bucket_id(self, tick, zero_for_one):
        return f'{tick}-{zero_for_one}'

    # TODO: exact amount out?
    def swap(self, amt_in, min_amt_out, zero_for_one):
        # P1 = 1.0001
        # p = P1^tick = price of X (ETH) in terms of Y (USDC)
        # px = y

        #            Y | X
        # - tick <- 1  | 0 -> + tick

        gas = 0
        oz = not zero_for_one
        p = self.p
        t = self.tick
        amt_in_rem = amt_in
        amt_out = 0
        while amt_in_rem > 0:
            assert gas < 100, "out of gas"
            gas += 1

            i = self.get_bucket_id(t, oz)
            s = self.slots[i]
            bucket = self.buckets.get(i, {}).get(s, None)
            print("tick", t, "rem", amt_in_rem, "slot", s, bucket)
            if bucket != None and bucket.total > 0:
                d_in = 0
                d_out = 0
                # TODO: math func to calculate max in and out (px = y)
                if zero_for_one:
                    # y <- x
                    max_out = bucket.amt1
                    max_in = max_out / p
                    d_out = min(max_out, amt_in_rem * p)
                    d_in = min(max_in, amt_in_rem)
                    bucket.amt0 += d_in
                    bucket.amt1 -= d_out
                    assert bucket.amt1 >= 0
                else:
                    # y -> x
                    max_out = bucket.amt0
                    max_in = max_out * p
                    d_out = min(max_out, amt_in_rem / p)
                    d_in = min(max_in, amt_in_rem)
                    bucket.amt0 -= d_out
                    bucket.amt1 += d_in
                    assert bucket.amt0 >= 0

                amt_out += d_out
                amt_in_rem -= d_in

                assert amt_in_rem >= 0

                if (zero_for_one and bucket.amt1 == 0) or (not zero_for_one and bucket.amt0 == 0):
                    self.slots[i] += 1
                    if zero_for_one:
                        assert bucket.amt0 > 0
                        assert bucket.amt1 == 0
                        bucket.amt_out_rem = bucket.amt0
                    else:
                        assert bucket.amt0 == 0
                        assert bucket.amt1 > 0
                        bucket.amt_out_rem = bucket.amt1
            if amt_in_rem > 0:
                # TODO: efficient way to find the next tick
                if zero_for_one:
                    t -= self.tick_spacing
                    p /= self.dp    
                else:
                    t += self.tick_spacing
                    p *= self.dp
        self.p = p
        self.tick = t
        assert amt_out >= min_amt_out

        return amt_out
    
    def inc(self, tick, zero_for_one, amt, **kwargs):
        msg_sender = kwargs["msg_sender"]

        assert tick % self.tick_spacing == 0
        if zero_for_one:
            assert self.tick < tick
        else:
            assert tick < self.tick

        i = self.get_bucket_id(tick, zero_for_one)
        s = self.slots[i]
        if self.buckets.get(i) is None:
            self.buckets[i] = {}
        if self.buckets[i].get(s) is None:
            self.buckets[i][s] = Bucket()

        bucket = self.buckets[i][s]
        bucket.inc(msg_sender, amt)
        if zero_for_one:
            bucket.amt0 += amt
        else:
            bucket.amt1 += amt

    def dec(self, tick, zero_for_one, amt, **kwargs):
        msg_sender = kwargs["msg_sender"]

        i = self.get_bucket_id(tick, zero_for_one)
        s = self.slots[i]
        bucket = self.buckets[i][s]

        assert amt <= bucket.amt_ins[msg_sender]

        amt0_out = bucket.amt0 * amt / bucket.total
        amt1_out = bucket.amt1 * amt / bucket.total

        bucket.amt0 -= amt0_out
        bucket.amt1 -= amt1_out
        # TODO: correct math? need to do vault math?
        bucket.dec(msg_sender, amt)

        # Dust
        if bucket.total == 0:
            amt0_out += bucket.amt0
            amt1_out += bucket.amt1
            del self.buckets[i][s]

        return (amt0_out, amt1_out)

    def take(self, slot, tick, zero_for_one, **kwargs):
        msg_sender = kwargs["msg_sender"]

        i = self.get_bucket_id(tick, zero_for_one)
        assert slot < self.slots[i]
        bucket = self.buckets[i][slot]

        total_out = 0
        if zero_for_one:
            assert bucket.amt0 == 0
            assert bucket.amt1 > 0
            total_out = bucket.amt1
        else:
            assert bucket.amt0 > 0
            assert bucket.amt1 == 0
            total_out = bucket.amt0

        amt_in = bucket.amt_ins[msg_sender]
        amt_out = total_out * amt_in / bucket.total

        bucket.amt_ins[msg_sender] = 0
        bucket.rem -= amt_in
        bucket.amt_out_rem -= amt_out

        # Dust
        if bucket.rem == 0:
            amt_out += bucket.amt_out_rem
            del self.buckets[i][slot]
        
        # TODO: check amt_out approx = P**tick * a_in?
        return amt_out


In [24]:
book = OrderBook()
book.inc(10, True, 100, msg_sender = "bob")
book.inc(10, True, 100, msg_sender = "alice")
book.inc(100, True, 110, msg_sender = "charlie")
book.swap(210, 0, False)

for (t, sb) in book.buckets.items():
    print(t)
    for (s, b) in sb.items():
        print(s, b)

book.take(0, 10, True, msg_sender = "alice")

tick 0 rem 210 slot 0 None
tick 10 rem 210 slot 0 amt0: 200 | amt1: 0 | total: 200 | amt_ins: defaultdict(<class 'int'>, {'bob': 100, 'alice': 100})
tick 20 rem 10.0 slot 0 None
tick 30 rem 10.0 slot 0 None
tick 40 rem 10.0 slot 0 None
tick 50 rem 10.0 slot 0 None
tick 60 rem 10.0 slot 0 None
tick 70 rem 10.0 slot 0 None
tick 80 rem 10.0 slot 0 None
tick 90 rem 10.0 slot 0 None
tick 100 rem 10.0 slot 0 amt0: 110 | amt1: 0 | total: 110 | amt_ins: defaultdict(<class 'int'>, {'charlie': 110})
10-True
0 amt0: 0 | amt1: 200.0 | total: 200 | amt_ins: defaultdict(<class 'int'>, {'bob': 100, 'alice': 100})
100-True
0 amt0: 100.0 | amt1: 10.0 | total: 110 | amt_ins: defaultdict(<class 'int'>, {'charlie': 110})


100.0

In [2]:
def calc_amts(tick0, tick1, liq):
    amt0 = 0
    amt1 = 0

    # L = px + y
    #  Y | X
    #  tick0
    if tick0 < tick1:
        # L = px
        # TODO: use rpow
        p = P ** tick1
        amt0 = liq / p
        amt1 = 0
    elif tick1 < tick0:
        # L = y
        amt0 = 0
        amt1 = liq
    else:
        # px = y = L / 2
        p = P ** tick0
        amt0 = liq / (2 * p)
        amt1 = liq / 2

    return (amt0, amt1)

0