In [19]:
from collections import defaultdict

class Bucket:
    """Pooled deposits placed by one or more users in a single order slot.

    ``inc``/``dec`` keep ``total``, ``rem`` and the per-user ``amt_ins``
    ledger in step; the token balances ``amt0``/``amt1`` and
    ``amt_out_rem`` are adjusted directly by the order book.
    """

    def __init__(self):
        # Token balances currently held by this bucket.
        self.amt0 = self.amt1 = 0
        # Proceeds still owed to depositors once the bucket has been filled.
        self.amt_out_rem = 0
        # Sum of all deposits, and the part of it not yet withdrawn.
        self.total = self.rem = 0
        # user -> amount deposited by that user
        self.amt_ins = defaultdict(int)

    def inc(self, user, amt):
        """Record a deposit of ``amt`` by ``user``."""
        self.total, self.rem = self.total + amt, self.rem + amt
        self.amt_ins[user] += amt

    def dec(self, user, amt):
        """Record a withdrawal of ``amt`` by ``user``."""
        self.total, self.rem = self.total - amt, self.rem - amt
        self.amt_ins[user] -= amt

    def __repr__(self):
        shown = (
            ("amt0", self.amt0),
            ("amt1", self.amt1),
            ("total", self.total),
            ("amt_ins", self.amt_ins),
        )
        return " | ".join(f"{label}: {value}" for label, value in shown)

# Ratio between prices at adjacent ticks; OrderBook prices tick t at P1**t.
P1: float = 1.01
    
# NOTE: OrderBook.__init__ asserts that tick % tick_spacing == 0.
class OrderBook:
    """Toy tick-based limit-order book with pooled orders per (tick, side).

    Resting orders are pooled into ``Bucket`` objects keyed by a bucket id
    (``f'{tick}-{zero_for_one}'``) and a slot number. ``self.slots[id]`` is the
    slot currently accepting deposits for that id; when a swap fully fills the
    bucket in that slot, the slot counter advances and the filled bucket stays
    behind so its makers can ``take`` their proceeds.

    Prices follow ``p = P1**tick`` (module-level ``P1``) with ``p * x = y``:
    ``p`` is the price of token0 (x) in units of token1 (y).
    """

    def __init__(self, tick=0, tick_spacing=10):
        """Create an empty book positioned at ``tick``.

        Args:
            tick: starting tick; must be a multiple of ``tick_spacing``.
            tick_spacing: distance between usable ticks; must be positive.
        """
        assert tick_spacing > 0
        assert tick % tick_spacing == 0
        self.tick = tick
        self.tick_spacing = tick_spacing
        # Price ratio between adjacent usable ticks, and the current price.
        self.dp = P1**tick_spacing
        self.p = P1**tick
        # bucket id => current (open) slot
        self.slots = defaultdict(int)
        # bucket id => slot => Bucket
        self.buckets = {}

    def get_bucket_id(self, tick, zero_for_one):
        """Return the key used in ``slots``/``buckets`` for one side of a tick."""
        return f'{tick}-{zero_for_one}'

    @staticmethod
    def _fill_amounts(bucket, rem, p, zero_for_one):
        """Return ``(d_in, d_out)`` for filling ``rem`` against ``bucket`` at ``p``.

        When the remaining request meets or exceeds the bucket's capacity the
        whole bucket is taken (``d_out`` equals its full out-token balance), so
        the balance reaches exactly 0 and the slot can advance. Otherwise the
        request is satisfied exactly and ``rem`` becomes 0.
        """
        if zero_for_one:
            # x in, y out: y = x * p
            max_out = bucket.amt1
            max_in = max_out / p
            if rem >= 0:
                # rem = x still to pay in
                if rem >= max_in:
                    return max_in, max_out
                return rem, min(max_out, rem * p)
            # -rem = y still to receive
            if -rem >= max_out:
                return max_in, max_out
            return min(max_in, -rem / p), -rem
        # y in, x out: y = x * p
        max_out = bucket.amt0
        max_in = max_out * p
        if rem >= 0:
            # rem = y still to pay in
            if rem >= max_in:
                return max_in, max_out
            return rem, min(max_out, rem / p)
        # -rem = x still to receive
        if -rem >= max_out:
            return max_in, max_out
        return min(max_in, -rem * p), -rem

    def swap(self, delta, min_max, zero_for_one, *, verbose=True):
        """Swap against resting orders, walking ticks until ``delta`` is filled.

        Args:
            delta: if >= 0, the exact amount of the input token to pay in;
                if < 0, ``-delta`` is the exact amount of the output token
                to receive.
            min_max: slippage bound -- minimum ``amt_out`` when ``delta >= 0``,
                maximum ``amt_in`` when ``delta < 0``.
            zero_for_one: True pays token0 (x) for token1 (y) and walks the
                tick down; False pays token1 for token0 and walks it up.
            verbose: print a trace line for every tick visited.

        Returns:
            ``(amt_in, amt_out)`` actually paid in and received.

        Raises:
            AssertionError: after 100 ticks ("out of gas") or when the
                slippage bound is violated. In that case every bucket,
                slot, tick and price touched by this call is restored.
        """
        #            Y | X
        # - tick <- 1  | 0 -> + tick
        # Swaps consume resting orders placed on the opposite side.
        oz = not zero_for_one
        t = self.tick
        rem = delta
        amt_in = 0
        amt_out = 0
        gas = 0

        # Undo journal so a failed swap leaves the book untouched
        # (the equivalent of an on-chain revert).
        saved_buckets = []  # (bucket, amt0, amt1, amt_out_rem) before mutation
        saved_slots = {}    # bucket id => slot before this swap
        committed = False
        try:
            while rem != 0:
                assert gas < 100, "out of gas"
                gas += 1

                # Recompute from the tick instead of multiplying by dp each
                # step, so float error does not accumulate across ticks.
                p = P1**t
                i = self.get_bucket_id(t, oz)
                # .get avoids inserting a slot entry for every tick visited.
                s = self.slots.get(i, 0)
                bucket = self.buckets.get(i, {}).get(s)

                if verbose:
                    print("tick", t, "p", p, "rem", rem, "slot", s, bucket)

                if bucket is not None and bucket.total > 0:
                    saved_buckets.append(
                        (bucket, bucket.amt0, bucket.amt1, bucket.amt_out_rem)
                    )
                    d_in, d_out = self._fill_amounts(bucket, rem, p, zero_for_one)
                    if zero_for_one:
                        # y <- x
                        bucket.amt0 += d_in
                        bucket.amt1 -= d_out
                        assert bucket.amt1 >= 0
                    else:
                        # y -> x
                        bucket.amt0 -= d_out
                        bucket.amt1 += d_in
                        assert bucket.amt0 >= 0

                    amt_in += d_in
                    amt_out += d_out
                    # rem >= 0 tracks input still to pay in;
                    # rem < 0 tracks output still to receive.
                    if rem >= 0:
                        rem -= d_in
                        assert rem >= 0, f'rem: {rem}'
                    else:
                        rem += d_out
                        assert rem <= 0, f'rem: {rem}'

                    exhausted = (
                        (zero_for_one and bucket.amt1 == 0)
                        or (not zero_for_one and bucket.amt0 == 0)
                    )
                    if exhausted:
                        # Open a fresh slot for new deposits; the filled bucket
                        # keeps the proceeds that makers claim via take().
                        saved_slots.setdefault(i, s)
                        self.slots[i] = s + 1
                        if zero_for_one:
                            assert bucket.amt0 > 0
                            assert bucket.amt1 == 0
                            bucket.amt_out_rem = bucket.amt0
                        else:
                            assert bucket.amt0 == 0
                            assert bucket.amt1 > 0
                            bucket.amt_out_rem = bucket.amt1

                if rem != 0:
                    # TODO: efficient way to find the next tick
                    if zero_for_one:
                        t -= self.tick_spacing
                    else:
                        t += self.tick_spacing

            if delta >= 0:
                assert amt_out >= min_max, "out < min"
            else:
                assert amt_in <= min_max, "in > max"
            committed = True
        finally:
            if not committed:
                for b, a0, a1, out_rem in reversed(saved_buckets):
                    b.amt0 = a0
                    b.amt1 = a1
                    b.amt_out_rem = out_rem
                for bucket_id, slot in saved_slots.items():
                    self.slots[bucket_id] = slot

        # Only move the book once the swap is known to succeed.
        self.tick = t
        self.p = P1**t
        return (amt_in, amt_out)

    def inc(self, tick, zero_for_one, amt, **kwargs):
        """Deposit ``amt`` as a resting order at ``tick`` for ``msg_sender``.

        ``zero_for_one=True`` deposits token0 at a tick above the current one;
        ``zero_for_one=False`` deposits token1 at a tick below it. The deposit
        joins the bucket in the currently open slot for that tick and side.
        """
        msg_sender = kwargs["msg_sender"]
        assert amt > 0, "amt must be positive"

        assert tick % self.tick_spacing == 0
        if zero_for_one:
            assert self.tick < tick
        else:
            assert tick < self.tick

        i = self.get_bucket_id(tick, zero_for_one)
        s = self.slots.get(i, 0)
        by_slot = self.buckets.setdefault(i, {})
        bucket = by_slot.get(s)
        if bucket is None:
            bucket = by_slot[s] = Bucket()

        bucket.inc(msg_sender, amt)
        if zero_for_one:
            bucket.amt0 += amt
        else:
            bucket.amt1 += amt

    def dec(self, tick, zero_for_one, amt, **kwargs):
        """Withdraw ``amt`` of ``msg_sender``'s deposit from the open slot.

        The withdrawal is pro-rata over both token balances, so a partially
        filled order returns some of each token.

        Returns:
            ``(amt0_out, amt1_out)``.
        """
        msg_sender = kwargs["msg_sender"]
        assert amt > 0, "amt must be positive"

        i = self.get_bucket_id(tick, zero_for_one)
        s = self.slots.get(i, 0)
        bucket = self.buckets[i][s]

        # .get avoids inserting a zero entry for an unknown sender.
        assert amt <= bucket.amt_ins.get(msg_sender, 0)

        amt0_out = bucket.amt0 * amt / bucket.total
        amt1_out = bucket.amt1 * amt / bucket.total

        bucket.amt0 -= amt0_out
        bucket.amt1 -= amt1_out
        # TODO: correct math? need to do vault math?
        bucket.dec(msg_sender, amt)

        # Last depositor out sweeps any rounding dust and frees the bucket.
        if bucket.total == 0:
            amt0_out += bucket.amt0
            amt1_out += bucket.amt1
            del self.buckets[i][s]

        return (amt0_out, amt1_out)

    def take(self, slot, tick, zero_for_one, **kwargs):
        """Claim ``msg_sender``'s pro-rata proceeds from a filled bucket.

        ``slot`` must be an already-filled slot (below the open slot).

        Returns:
            The amount of the bucket's out-token paid to ``msg_sender``.
        """
        msg_sender = kwargs["msg_sender"]

        i = self.get_bucket_id(tick, zero_for_one)
        assert slot < self.slots.get(i, 0)
        bucket = self.buckets[i][slot]

        if zero_for_one:
            # Makers deposited token0; a filled bucket holds only token1.
            assert bucket.amt0 == 0
            assert bucket.amt1 > 0
            total_out = bucket.amt1
        else:
            assert bucket.amt0 > 0
            assert bucket.amt1 == 0
            total_out = bucket.amt0

        amt_in = bucket.amt_ins[msg_sender]
        # take() never reduces bucket.total or the token balances, so every
        # maker's share is computed against the same totals.
        amt_out = total_out * amt_in / bucket.total

        bucket.amt_ins[msg_sender] = 0
        bucket.rem -= amt_in
        bucket.amt_out_rem -= amt_out

        # Last maker out collects the rounding dust and the bucket is freed.
        if bucket.rem == 0:
            amt_out += bucket.amt_out_rem
            del self.buckets[i][slot]

        # TODO: check amt_out approx = P**tick * a_in?
        return amt_out


In [20]:
# Two makers share the tick-10 bucket; a third rests further out at tick 100.
book = OrderBook()
for maker, order_tick, order_amt in (
    ("bob", 10, 100),
    ("alice", 10, 100),
    ("charlie", 100, 110),
):
    book.inc(order_tick, True, order_amt, msg_sender=maker)

# Receive exactly 210 of token0, paying at most 300 of token1.
amt_in, amt_out = book.swap(-210, 300, False)
print("in", amt_in, "out", amt_out)

# Dump every bucket, grouped by bucket id and then by slot.
for bucket_id, by_slot in book.buckets.items():
    print(bucket_id)
    for slot_no, bkt in by_slot.items():
        print(slot_no, bkt)

book.take(0, 10, True, msg_sender="alice")

tick 0 p 1.0 rem -210 slot 0 None
tick 10 p 1.1046221254112045 rem -210 slot 0 amt0: 200 | amt1: 0 | total: 200 | amt_ins: defaultdict(<class 'int'>, {'bob': 100, 'alice': 100})
tick 20 p 1.2201900399479668 rem -10 slot 0 None
tick 30 p 1.3478489153329056 rem -10 slot 0 None
tick 40 p 1.4888637335882209 rem -10 slot 0 None
tick 50 p 1.644631821843882 rem -10 slot 0 None
tick 60 p 1.8166966985640904 rem -10 slot 0 None
tick 70 p 2.006763368395384 rem -10 slot 0 None
tick 80 p 2.216715217194257 rem -10 slot 0 None
tick 90 p 2.44863267464848 rem -10 slot 0 None
tick 100 p 2.7048138294215267 rem -10 slot 0 amt0: 110 | amt1: 0 | total: 110 | amt_ins: defaultdict(<class 'int'>, {'charlie': 110})
in 247.97256337645618 out 210
10-True
0 amt0: 0 | amt1: 220.9244250822409 | total: 200 | amt_ins: defaultdict(<class 'int'>, {'bob': 100, 'alice': 100})
100-True
0 amt0: 100 | amt1: 27.048138294215267 | total: 110 | amt_ins: defaultdict(<class 'int'>, {'charlie': 110})


110.46221254112044

In [34]:
from collections import defaultdict

# Signed 24-bit tick range.
i24_max = (1 << 23) - 1
i24_min = -(1 << 23)

# 2**24 = (2**8)**3 = 256**3, so a 24-bit tick splits into three bytes:
# hi (8 bits) | mid (8 bits) | lo (8 bits)
u24 = 1 << 24
print("u24", u24)

# Two-level mapping hi -> mid -> bitmap word, where bit `lo` of the word marks
# a tick. lo ranges over 0..255, so each word needs 256 bits (not 8).
ticks = defaultdict(lambda: defaultdict(int))

# All 24 bits set.
i24 = u24 - 1

def split(i24):
    """Split a tick into the (hi, mid, lo) bytes of its 24-bit encoding.

    Python's arithmetic right shift means negative ticks yield the bytes of
    their 24-bit two's-complement form (e.g. -1 -> (255, 255, 255)).
    """
    hi = (i24 >> 16) & 0xFF
    mid = (i24 >> 8) & 0xFF
    lo = i24 & 0xFF
    return (hi, mid, lo)


def insert(ticks, i24):
    """Mark tick ``i24`` by setting bit ``lo`` of the word ``ticks[hi][mid]``."""
    (hi, mid, lo) = split(i24)
    # OR sets bit lo while preserving every other tick already in this word.
    ticks[hi][mid] |= 1 << lo


def remove(ticks, i24):
    """Unmark tick ``i24`` by clearing bit ``lo`` of the word ``ticks[hi][mid]``."""
    (hi, mid, lo) = split(i24)
    # lo spans 0..255, so the word is a 256-bit bitmap: clear only bit lo and
    # apply no 8-bit mask (which would wipe bits 8..255).
    ticks[hi][mid] &= ~(1 << lo)

def find(ticks, i24, gt=True):
    """Locate the next marked tick around ``i24`` -- TODO: not implemented.

    Currently always returns None. Design notes carried over:
      * gt=True  -> search for a set tick bit to the left (higher bits).
        Sketch: msb(lo) -> mask = ~(msb & (msb - 1)) -> ticks[hi][mid] & mask.
        NOTE(review): for a power-of-two msb, msb & (msb - 1) is 0, so that
        mask keeps every bit; revisit before implementing.
      * gt=False -> search for a set tick bit to the right (lower bits).
    """
    return None

def int24(val: int) -> int:
    """Interpret the low 24 bits of ``val`` as a signed two's-complement int24.

    >>> int24(0xFFFFFF)
    -1
    """
    # Offsetting by half the range before masking maps [0x800000, 0xFFFFFF]
    # below zero and leaves [0, 0x7FFFFF] unchanged.
    half = 1 << 23
    return ((val + half) & 0xFFFFFF) - half

u24 16777216
255 255 255
