In [7]:
from collections import defaultdict

class Bucket:
    """Shared pool for every order resting in one (tick, slot) position.

    Attributes (numeric ones start at zero):
        amt0, amt1: token balances currently held by the pool.
        zero_for_one: direction flag the pool was created with.
        total_in: running sum of deposits credited through ``inc``/``dec``.
        total_in_rem: same running sum; the order book also reduces it as
            depositors claim their share.
        amt_out_rem: output still owed to claimers (assigned by the order book).
        amt_ins: deposited amount per user.
    """

    def __init__(self, zero_for_one):
        self.amt0 = self.amt1 = 0
        self.zero_for_one = zero_for_one
        self.total_in = self.total_in_rem = 0
        self.amt_out_rem = 0
        self.amt_ins = defaultdict(int)

    def inc(self, user, amt):
        """Credit a deposit of ``amt`` to ``user``."""
        self.total_in, self.total_in_rem = self.total_in + amt, self.total_in_rem + amt
        self.amt_ins[user] += amt

    def dec(self, user, amt):
        """Debit ``amt`` from ``user``'s deposit (no bounds check here)."""
        self.total_in, self.total_in_rem = self.total_in - amt, self.total_in_rem - amt
        self.amt_ins[user] -= amt

    def __repr__(self):
        # Note: total_in_rem and amt_out_rem are intentionally not shown.
        return (
            f'amt0: {self.amt0} | amt1: {self.amt1} | zero_for_one: {self.zero_for_one} | total_in: {self.total_in} | amt_ins: {self.amt_ins}'
        )

P1: float = 1.01  # price ratio between adjacent ticks: price(tick) == P1 ** tick

class OrderBook:
    """Toy tick-based limit-order book with orders pooled per (tick, slot).

    The price at tick ``t`` is ``P1 ** t``, quoted as token1 (Y) per token0
    (X), i.e. ``p * x = y``. Orders selling token0 (``zero_for_one=True``)
    may only be placed above the current tick and orders selling token1 only
    below it. New deposits at a tick join the bucket of that tick's current
    slot. Once a swap fills that bucket completely, the slot counter advances.
    This freezes the filled bucket so its depositors can claim with ``take``.

    Checks are written as ``assert`` (``require``-style) and raise
    ``AssertionError``. NOTE(review): they disappear under ``python -O``.
    """

    def __init__(self):
        self.tick = 0
        self.tick_spacing = 10
        # Price ratio between neighbouring usable ticks.
        self.dp = P1**self.tick_spacing
        # Current price; tracks P1 ** self.tick (up to float rounding).
        self.p = 1
        # tick => current slot
        self.slots = defaultdict(int)
        # tick => slot => Bucket
        self.buckets = {}

    # TODO: exact amount out?
    def swap(self, amt_in, min_amt_out, zero_for_one):
        """Swap ``amt_in`` of the input token against resting orders.

        Walks ticks away from the current one (down when selling token0, up
        otherwise), trading with each tick's current bucket at that tick's
        price until ``amt_in`` is used up. Prints one trace line per tick
        visited.

        Args:
            amt_in: amount of the input token (token0 if ``zero_for_one``,
                else token1).
            min_amt_out: slippage guard; the swap fails if it would yield less.
            zero_for_one: ``True`` to sell token0 for token1.

        Returns:
            The amount of the output token received.

        Raises:
            AssertionError: if the output is below ``min_amt_out`` or more than
                100 ticks are walked. All bucket balances and slot counters
                touched are restored, and tick/price are not updated.
        """
        #            Y | X
        # - tick <- 1  | 0 -> + tick
        undo = []
        try:
            amt_out, t, p = self._walk(amt_in, zero_for_one, undo)
            assert amt_out >= min_amt_out
        except BaseException:
            # Emulate a transaction revert: never leave a half-applied swap.
            self._rollback(undo)
            raise
        self.p = p
        self.tick = t
        return amt_out

    def _walk(self, amt_in, zero_for_one, undo):
        """Fill buckets tick by tick; return ``(amt_out, final_tick, final_price)``.

        Each bucket's prior state is appended to ``undo`` before it is mutated.
        """
        gas = 0
        p = self.p
        t = self.tick
        amt_in_rem = amt_in
        amt_out = 0
        while amt_in_rem > 0:
            assert gas < 100
            gas += 1

            # .get() so that merely walking past a tick creates no slot entry.
            s = self.slots.get(t, 0)
            bucket = self.buckets.get(t, {}).get(s)
            print(t, amt_in_rem, s, bucket)
            if bucket is not None and bucket.total_in > 0:
                undo.append((t, s, bucket, bucket.amt0, bucket.amt1, bucket.amt_out_rem))
                d_in, d_out = self._fill(bucket, amt_in_rem, p, zero_for_one)
                amt_out += d_out
                amt_in_rem -= d_in
                assert amt_in_rem >= 0

                # Only a swap in the opposite direction fills the bucket's
                # order. A same-direction swap merely un-fills a partially
                # filled bucket, which must stay open in its slot.
                if bucket.zero_for_one != zero_for_one and self._is_filled(bucket):
                    # Freeze the filled bucket for `take`; new deposits at this
                    # tick go into a fresh slot.
                    self.slots[t] = s + 1
                    if zero_for_one:
                        assert bucket.amt0 > 0
                        assert bucket.amt1 == 0
                        bucket.amt_out_rem = bucket.amt0
                    else:
                        assert bucket.amt0 == 0
                        assert bucket.amt1 > 0
                        bucket.amt_out_rem = bucket.amt1
            # TODO: efficient way to find the next tick
            if amt_in_rem > 0:
                if zero_for_one:
                    t -= self.tick_spacing
                    p /= self.dp
                else:
                    t += self.tick_spacing
                    p *= self.dp
        return amt_out, t, p

    @staticmethod
    def _fill(bucket, amt_in_rem, p, zero_for_one):
        """Trade up to ``amt_in_rem`` with ``bucket`` at price ``p``; return ``(d_in, d_out)``.

        When the bucket's output side is the binding limit, it is drained
        exactly (``d_out = max_out``). This keeps float rounding from leaving
        dust that would stop the bucket from being recognised as filled.
        """
        # TODO: math func to calculate max in and out (px = y)
        if zero_for_one:
            # y <- x: pay token0, receive token1
            max_out = bucket.amt1
            max_in = max_out / p
            if amt_in_rem >= max_in:
                d_in, d_out = max_in, max_out
            else:
                d_in, d_out = amt_in_rem, min(max_out, amt_in_rem * p)
            bucket.amt0 += d_in
            bucket.amt1 -= d_out
            assert bucket.amt1 >= 0
        else:
            # y -> x: pay token1, receive token0
            max_out = bucket.amt0
            max_in = max_out * p
            if amt_in_rem >= max_in:
                d_in, d_out = max_in, max_out
            else:
                d_in, d_out = amt_in_rem, min(max_out, amt_in_rem / p)
            bucket.amt0 -= d_out
            bucket.amt1 += d_in
            assert bucket.amt0 >= 0
        return d_in, d_out

    @staticmethod
    def _is_filled(bucket):
        """Return True once the token ``bucket`` sells has been fully bought."""
        return bucket.amt0 == 0 if bucket.zero_for_one else bucket.amt1 == 0

    def _rollback(self, undo):
        """Restore every bucket balance and slot counter recorded in ``undo``."""
        for t, s, bucket, amt0, amt1, amt_out_rem in reversed(undo):
            bucket.amt0 = amt0
            bucket.amt1 = amt1
            bucket.amt_out_rem = amt_out_rem
            self.slots[t] = s

    def inc(self, tick, amt, zero_for_one, **kwargs):
        """Deposit ``amt`` as a limit order at ``tick`` for ``kwargs["msg_sender"]``.

        ``zero_for_one=True`` deposits token0 (must be above the current
        tick); otherwise token1 is deposited (must be below it).
        """
        msg_sender = kwargs["msg_sender"]

        assert tick % self.tick_spacing == 0

        if zero_for_one:
            assert self.tick < tick
        else:
            assert tick < self.tick

        s = self.slots[tick]
        slot_buckets = self.buckets.setdefault(tick, {})
        bucket = slot_buckets.get(s)
        if bucket is None:
            bucket = slot_buckets[s] = Bucket(zero_for_one)
        # A bucket holds one side only; mixing directions would corrupt it.
        assert bucket.zero_for_one == zero_for_one

        bucket.inc(msg_sender, amt)

        if zero_for_one:
            bucket.amt0 += amt
        else:
            bucket.amt1 += amt

    def dec(self, tick, amt, **kwargs):
        """Withdraw ``amt`` of the caller's deposit from ``tick``'s open bucket.

        Returns ``(amt0_out, amt1_out)``: the caller's pro-rata share of both
        balances (the bucket may be partially filled). When the last deposit
        leaves, any remaining dust goes to this caller and the bucket is
        deleted. Raises ``KeyError`` if there is no open bucket at ``tick``.
        """
        msg_sender = kwargs["msg_sender"]

        s = self.slots[tick]
        bucket = self.buckets[tick][s]

        assert amt <= bucket.amt_ins[msg_sender]

        # Pro-rata share, computed against total_in *before* it is reduced.
        amt0_out = bucket.amt0 * amt / bucket.total_in
        amt1_out = bucket.amt1 * amt / bucket.total_in

        bucket.amt0 -= amt0_out
        bucket.amt1 -= amt1_out
        # TODO: correct math? need to do vault math?
        bucket.dec(msg_sender, amt)

        # Dust
        if bucket.total_in == 0:
            amt0_out += bucket.amt0
            amt1_out += bucket.amt1
            del self.buckets[tick][s]

        return (amt0_out, amt1_out)

    def take(self, slot, tick, **kwargs):
        """Claim the caller's share of the filled bucket at (``tick``, ``slot``).

        Returns the caller's pro-rata amount of the token the bucket bought.
        The last claimer also receives the rounding dust, and the bucket is
        then deleted. A second claim by the same user returns 0.
        """
        msg_sender = kwargs["msg_sender"]

        # Only buckets whose slot has been advanced past (i.e. filled) are claimable.
        assert slot < self.slots[tick]
        bucket = self.buckets[tick][slot]

        assert (bucket.amt0 > 0 and bucket.amt1 == 0) or (bucket.amt0 == 0 and bucket.amt1 > 0)

        total_out = bucket.amt1 if bucket.zero_for_one else bucket.amt0
        assert total_out > 0

        amt_in = bucket.amt_ins[msg_sender]
        amt_out = total_out * amt_in / bucket.total_in

        bucket.amt_ins[msg_sender] = 0
        bucket.total_in_rem -= amt_in
        bucket.amt_out_rem -= amt_out

        # Dust
        if bucket.total_in_rem == 0:
            amt_out += bucket.amt_out_rem
            del self.buckets[tick][slot]

        # TODO: check amt_out approx = P**tick * a_in?
        return amt_out


In [17]:
# Three token0 sell orders: two pooled at tick 10, one at tick 30.
book = OrderBook()
for who, at_tick in (("bob", 10), ("alice", 10), ("charlie", 30)):
    book.inc(at_tick, 100, True, msg_sender=who)

# Buy token0 with 250 token1: fills tick 10 completely, tick 30 partially.
book.swap(250, 0, False)

# Dump every remaining bucket, grouped by tick.
for t, sb in book.buckets.items():
    print(t)
    for s, b in sb.items():
        print(s, b)

# Alice claims her share of the filled tick-10 bucket (slot 0).
book.take(0, 10, msg_sender="alice")

0 250 0 None
10 250 0 amt0: 200 | amt1: 0 | zero_for_one: True | total_in: 200 | amt_ins: defaultdict(<class 'int'>, {'bob': 100, 'alice': 100})
20 29.0755749177591 0 None
30 29.0755749177591 0 amt0: 100 | amt1: 0 | zero_for_one: True | total_in: 100 | amt_ins: defaultdict(<class 'int'>, {'charlie': 100})
10
0 amt0: 0 | amt1: 220.9244250822409 | zero_for_one: True | total_in: 200 | amt_ins: defaultdict(<class 'int'>, {'bob': 100, 'alice': 100})
30
0 amt0: 78.42816462067805 | amt1: 29.0755749177591 | zero_for_one: True | total_in: 100 | amt_ins: defaultdict(<class 'int'>, {'charlie': 100})


110.46221254112044