# Simulation

## Prelude

In [None]:
!pip install matplotlib numpy mpmath

In [None]:
%matplotlib inline

import matplotlib
import numpy as np
from mpmath import mp
import matplotlib.pyplot as plt
import math
import warnings
from ipywidgets import interact, interactive, fixed, interact_manual

In [None]:
# Figure defaults for every plot in this notebook.
plt.rcParams.update({
    'figure.dpi': 90,
    'figure.figsize': [12.0, 8.0],
    'text.usetex': False,
})

In [None]:
# Show only the settings overridden above instead of dumping the whole
# rcParams mapping (hundreds of entries).
{key: plt.rcParams[key] for key in ('figure.dpi', 'figure.figsize', 'text.usetex')}

In [None]:
# Record machine, timestamp, Python and imported package versions for reproducibility.
from watermark import watermark
report = watermark(machine=True, iso8601=True, python=True, iversions=True, globals_=globals())
print(report)

In [None]:
def intrange(bot, top, num = 100):
    """Return `num` evenly spaced Python ints from int(bot) to int(top).

    Samples come from np.linspace (so they pass through float64) and are then
    truncated to int and clamped: first to at most `top`, then to at least `bot`.
    """
    lo = int(bot)
    hi = int(top)
    samples = np.linspace(lo, hi, num=num)
    return [max(min(int(sample), hi), lo) for sample in samples]

## EVM primitives

In [None]:
BASE = 1 << 256   # number of distinct 256-bit EVM words (arithmetic modulus)
SIGN = BASE >> 1  # 2**255: words >= SIGN encode negative int256 values

In [None]:
def _valid(a):
    """Assert that `a` is a plain Python int holding a 256-bit word in [0, BASE)."""
    assert type(a) is int
    assert 0 <= a < BASE

def _abs(a):
    """Magnitude of a two's-complement word, returned as a word below SIGN.

    Asserts for a == SIGN (-2**255), whose magnitude is not representable
    below SIGN.
    """
    _valid(a)
    magnitude = BASE - a if a > SIGN else a
    assert 0 <= magnitude < SIGN
    return magnitude

def bnot(a):
    """EVM NOT: flip all 256 bits of word `a`."""
    _valid(a)
    # For 0 <= a < BASE, complementing every bit equals (BASE - 1) - a.
    return (BASE - 1) - a

def bor(a, b):
    """EVM OR: bitwise OR of two words."""
    for word in (a, b):
        _valid(word)
    return a | b

def lt(a, b):
    """EVM LT: 1 if word `a` < word `b` (unsigned comparison), else 0."""
    _valid(a)
    _valid(b)
    return int(a < b)

def shl(a, b):
    """EVM SHL: shift word `b` left by `a` bits, truncated to 256 bits.

    Args:
        a: shift amount (any 256-bit word).
        b: value to shift.

    Returns:
        (b << a) mod 2**256; 0 whenever a >= 256, as in the EVM.
    """
    _valid(a)
    _valid(b)
    # Every bit is shifted out once a >= 256. Returning early also avoids
    # building `b << a` for astronomically large shift words.
    if a >= 256:
        return 0
    return (b << a) % BASE

def shr(a, b):
    """EVM SHR: logical shift of word `b` right by `a` bits."""
    _valid(a)
    _valid(b)
    # A right shift of a value below BASE stays below BASE, so no wrap is needed.
    return b >> a

def sar(a, b):
    """EVM SAR: arithmetic shift right of two's-complement word `b` by `a` bits.

    Rounds toward negative infinity, like the SAR opcode and Solidity's `>>`
    on int256. For example -5 >> 1 == -3, -4 >> 1 == -2, and any negative
    value shifted by 256 or more gives -1 (all ones).

    Args:
        a: shift amount (any 256-bit word).
        b: word interpreted as a signed int256.

    Returns:
        The shifted value, re-encoded as a 256-bit word.
    """
    _valid(a)
    _valid(b)
    signed = b - BASE if b >= SIGN else b
    # Python's >> on negative ints already floors. Shifting a 256-bit value by
    # 256 saturates to 0 / -1, so clamping the count changes nothing but keeps
    # enormous shift words cheap.
    return (signed >> min(a, 256)) % BASE

def add(a, b):
    """EVM ADD: a + b modulo 2**256."""
    _valid(a)
    _valid(b)
    # Masking with BASE - 1 is reduction mod 2**256 (BASE is a power of two).
    return (a + b) & (BASE - 1)

def sub(a, b):
    """EVM SUB: a - b modulo 2**256 (negative results wrap to two's complement)."""
    _valid(a)
    _valid(b)
    # Python's & treats negative ints as infinite two's complement, so this
    # equals (a - b) % BASE.
    return (a - b) & (BASE - 1)

def mul(a, b):
    """EVM MUL: a * b modulo 2**256 (also correct for two's-complement operands)."""
    _valid(a)
    _valid(b)
    return (a * b) & (BASE - 1)

def div(a, b):
    """EVM DIV: unsigned integer division of two words.

    Returns:
        floor(a / b), or 0 when b == 0 (the EVM defines division by zero as 0
        instead of failing).
    """
    _valid(a)
    _valid(b)
    if b == 0:
        return 0
    return a // b

def sdiv(a, b):
    """EVM SDIV: signed division of two's-complement words, truncating toward zero.

    Follows the EVM edge cases:
      * division by zero returns 0;
      * -2**255 / -1 overflows back to -2**255 (the word SIGN).

    Returns:
        The quotient re-encoded as a 256-bit word.
    """
    _valid(a)
    _valid(b)
    if b == 0:
        return 0
    signed_a = a - BASE if a >= SIGN else a
    signed_b = b - BASE if b >= SIGN else b
    # Divide the magnitudes, then apply the sign: this truncates toward zero.
    quotient = abs(signed_a) // abs(signed_b)
    if (signed_a < 0) != (signed_b < 0):
        quotient = -quotient
    # `% BASE` re-encodes negatives and wraps the single overflow case (+2**255).
    return quotient % BASE

def mod(a, b):
    """EVM MOD: unsigned remainder a mod b, or 0 when b == 0 (EVM semantics)."""
    _valid(a)
    _valid(b)
    if b == 0:
        return 0
    # a % b < b < BASE, so no extra reduction is needed.
    return a % b

def smod(a, b):
    """EVM SMOD: signed remainder of two's-complement words.

    The result has the sign of the dividend `a` and magnitude |a| mod |b|.
    It is 0 when b == 0 (EVM semantics).

    Returns:
        The remainder re-encoded as a 256-bit word.
    """
    _valid(a)
    _valid(b)
    if b == 0:
        return 0
    magnitude_b = BASE - b if b >= SIGN else b
    if a < SIGN:
        return a % magnitude_b
    remainder = (BASE - a) % magnitude_b
    # The final `% BASE` maps a zero remainder to 0; without it an exact
    # multiple (e.g. -4 smod 2) produced the out-of-range word BASE.
    return (BASE - remainder) % BASE

def mulmod(a, b, c):
    """EVM MULMOD: (a * b) mod c using the exact (unbounded) product.

    Returns 0 when c == 0 (EVM semantics).
    """
    _valid(a)
    _valid(b)
    _valid(c)
    if c == 0:
        return 0
    # The remainder is already below c < BASE.
    return (a * b) % c

## Ground truth

In [None]:
# mpmath working precision, in bits, for every ground-truth computation below.
PRECISION = 1 << 10  # 1024 bits
mp.prec = PRECISION

In [None]:
MP_FIX_1 = mp.mpf(1 << 127)  # fixed-point 1.0 (127 fractional bits) as an mpmath value

In [None]:
def fix_to_mp(x):
    """Decode a two's-complement fixed-point word (127 fractional bits) to an mpmath real."""
    _valid(x)
    negative = x >= SIGN
    magnitude = BASE - x if negative else x
    value = mp.mpf(magnitude) / MP_FIX_1
    return -value if negative else value

In [None]:
def fix_to_f(x):
    """Decode a fixed-point word to a Python float (via its exact mpmath value)."""
    exact = fix_to_mp(x)
    return float(exact)

In [None]:
def mp_to_fix(x):
    """Encode a real number as the nearest fixed-point word (127 fractional bits).

    Values outside the representable range [-2**128, 2**128) clamp to the
    extreme words: SIGN - 1 (maximum) and SIGN (minimum). NaN maps to the
    maximum, as if it were positive infinity.

    Args:
        x: mpmath real, float or int.

    Returns:
        A word in [0, BASE); negatives are encoded in two's complement.
    """
    if mp.isnan(x):
        return SIGN - 1
    # The range limit in real units is SIGN / 2**127 == 2**128, not SIGN.
    # Comparing before scaling keeps infinities and huge values away from int().
    limit = SIGN / MP_FIX_1
    if x >= limit:
        return SIGN - 1
    if x < -limit:
        return SIGN
    scaled = int(mp.nint(x * MP_FIX_1))
    # Rounding can still push a value just below the limit onto 2**255.
    scaled = max(-SIGN, min(SIGN - 1, scaled))
    # `% BASE` encodes negatives in two's complement. It also maps a tiny
    # negative that rounds to 0 onto word 0 rather than the invalid word BASE.
    x = scaled % BASE
    _valid(x)
    return x

In [None]:
def fmul(a, b):
    """Ground-truth fixed-point product of words `a` and `b`, rounded to nearest.

    Previously called the undefined `bf_to_fix`; `mp_to_fix` is the encoder
    defined in this notebook.
    """
    return mp_to_fix(fix_to_mp(a) * fix_to_mp(b))

In [None]:
def mp_log1(x):
    """Ground-truth natural log of fixed-point word `x`, re-encoded as a word."""
    real = fix_to_mp(x)
    return mp_to_fix(mp.log(real))

In [None]:
def mp_ln(x):
    """Ground-truth ln(x) for a fixed-point word `x`, as a fixed-point word."""
    logarithm = mp.log(fix_to_mp(x))
    return mp_to_fix(logarithm)

In [None]:
def mp_exp(x):
    """Ground-truth exp(x) for a fixed-point word `x`, as a fixed-point word."""
    exponential = mp.exp(fix_to_mp(x))
    return mp_to_fix(exponential)

## Chebyshev approximation generator

In [None]:
def _print_chebyfun_solidity(start, end, mid, coeffs, error_fix, error_bits):
    """Print Solidity source performing the same Horner evaluation as chebyfun's `func`."""
    print('// Chebyshev approximation on ({:.4g}, {:.4g}) deg {}.'.format(float(start), float(end), len(coeffs) - 1))
    print('// Max observed error {:.2g}, last {:.2g} bits.'.format(fix_to_f(error_fix), error_bits))
    if mid < SIGN:
        print('x -= 0x{:x}; // {:.3g}'.format(mid, fix_to_f(mid)))
    else:
        print('x += 0x{:x}; // {:.3g}'.format(BASE - mid, fix_to_f(mid)))
    if coeffs[0] < SIGN:
        print('int256 r = 0x{:x}; // {:.3g}'.format(coeffs[0], fix_to_f(coeffs[0])))
    else:
        print('int256 r = -0x{:x}; // {:.3g}'.format(BASE - coeffs[0], fix_to_f(coeffs[0])))
    # coeffs[0] initialised r above; the Horner steps use only the remaining
    # coefficients, exactly like the simulated `func`.
    for coeff in coeffs[1:]:
        if coeff < SIGN:
            print('r = ((r * x) >> 127) + 0x{:x}; // {:.3g}'.format(coeff, fix_to_f(coeff)))
        else:
            print('r = ((r * x) >> 127) - 0x{:x}; // {:.3g}'.format(BASE - coeff, fix_to_f(coeff)))


def chebyfun(f, domain, degree=11):
    """Fit a Chebyshev polynomial to `f`, plot its fixed-point error, print Solidity.

    Args:
        f: mpmath-compatible function of one real variable.
        domain: (start, end) interval of the fit, in real units.
        degree: passed to `mp.chebyfit` as N (the number of coefficients); the
            fitted polynomial has len(coeffs) - 1 as its degree.

    Returns:
        func(x): evaluates the polynomial on a fixed-point word `x` with the
        same 256-bit operations as the printed Solidity (Horner's scheme with
        `>> 127` after every multiply).
    """
    start, end = domain
    start = mp.mpf(start)
    end = mp.mpf(end)
    mid = start + (end - start) / 2

    # Fit in the variable t = x - mid so the polynomial is centred on zero,
    # which keeps the fixed-point Horner evaluation well behaved.
    # Alternatively we could use Clenshaw's algorithm.
    # TODO: Centering helps a lot, what about scaling?
    coeffs = [mp_to_fix(c) for c in mp.chebyfit(lambda t: f(t + mid), [start - mid, end - mid], degree)]
    mid_fix = mp_to_fix(mid)

    # Solidity function
    def func(x):
        x = sub(x, mid_fix)
        r = coeffs[0]
        for coeff in coeffs[1:]:
            # r = (r * x >> 127) + coeff
            r = add(sar(127, mul(r, x)), coeff)
        return r

    # Plot log2 of the error. The error is the signed difference of the two
    # words (sub + _abs), so results on opposite sides of zero are not
    # reported as ~2**256 apart.
    xs = intrange(mp_to_fix(start), mp_to_fix(end), num=10000)
    xr = [fix_to_f(v) for v in xs]
    errors = [_abs(sub(func(v), mp_to_fix(f(fix_to_mp(v))))) for v in xs]
    with np.errstate(divide='ignore'):  # exact hits give log2(0) = -inf
        error_log2 = np.log2([float(e) for e in errors])
    plt.title('Error over domain')
    plt.xlabel("$x$")
    plt.ylabel("$\\log_2 \\vert f(x) - r \\vert$")
    plt.plot(xr, error_log2)
    error_fix = max(errors)
    error_bits = error_log2.max()

    _print_chebyfun_solidity(start, end, mid_fix, coeffs, error_fix, error_bits)
    return func

In [None]:
# A bare int default makes ipywidgets build a slider from -11 to 33, which
# includes degrees chebyfit cannot handle. A (min, max) tuple gives a bounded
# slider that starts at its midpoint (11).
interact(lambda d: chebyfun(mp.log, (0.8825, 0.99999999), degree=d), d=(2, 20))

In [None]:
# Simulated fixed-point ln on (0.8825, 0.99999999). 0.8825 ≈ e**-0.125, the
# smallest reduction constant in lib_ln, so this range presumably matches the
# arguments left after its reductions.
my_log1 = chebyfun(mp.log, domain=(0.8825, 0.99999999), degree=11)

In [None]:
# Simulated fixed-point exp on [0, 0.125], using 17 Chebyshev coefficients.
my_exp = chebyfun(mp.exp, domain=(0.0, 0.125), degree=17)

In [None]:
def chebyshev_nodes(domain, n):
    """Return the n Chebyshev nodes of the first kind on `domain`, as mpmath reals.

    Node k is centre - half_width * cos(pi * (2k + 1) / (2n)), for k = 0 .. n-1.
    """
    lo, hi = (mp.mpf(bound) for bound in domain)
    centre = (lo + hi) / 2
    half_width = (hi - lo) / 2
    return [
        centre - half_width * mp.cos(mp.pi * (mp.mpf(2 * k + 1) / mp.mpf(2 * n)))
        for k in range(n)
    ]

In [None]:
def least_squares(A, b):
    """Least-squares solution of A c = b via the SVD pseudo-inverse (mpmath).

    With A = U diag(S) V, this returns V^T diag(1/S) U^T b. Singular values
    that are exactly zero are kept as zero instead of being inverted.
    """
    U, S, V = mp.svd_r(A)
    inverse_singular = [mp.one / S[i] if S[i] != 0 else S[i] for i in range(len(S))]
    return V.T * mp.diag(inverse_singular) * U.T * b

In [None]:
def vandermonde(x, n):
    """mpmath matrix whose row i is [1, x_i, x_i**2, ..., x_i**(n-1)]."""
    A = mp.matrix(len(x), n)
    for row, node in enumerate(x):
        power = mp.one
        for col in range(n):
            A[row, col] = power
            power *= node  # build powers by repeated multiplication
    return A

In [None]:
# Least-squares fit of exp on [0, 0.125], shifted by 0.0625 so the fit domain
# is centred on zero, at 17 Chebyshev nodes. Coefficients are lowest degree first.
f = lambda x: mp.exp(x + 0.0625)
d = (0.0 - 0.0625, 0.125 - 0.0625)
x = chebyshev_nodes(d, 17)
y = mp.matrix([f(node) for node in x])
c = least_squares(vandermonde(x, 17), y)
[hex(mp_to_fix(coeff)) for coeff in c]

In [None]:
def extrema(domain, f, c, derivative=None, samples=1000):
    """Locate the local extrema of the approximation error p(x) - f(x) on `domain`.

    Args:
        domain: (start, end) interval to search.
        f: the target function (mpmath-compatible).
        c: polynomial coefficients, lowest degree first, as returned by
            `least_squares(vandermonde(x, n), y)`. NOTE(review): the output of
            `least_squares(remez(...), y)` carries the levelled error E as an
            extra trailing entry; drop it before calling.
        derivative: f'(x); estimated numerically with `mp.diff` when omitted.
        samples: number of grid intervals used to bracket sign changes of the
            error's derivative.

    Returns:
        Sorted list of x positions: both endpoints plus every interior root of
        p'(x) - f'(x) bracketed between adjacent grid points.
    """
    # If no derivative of f is provided, produce one numerically
    if derivative is None:
        derivative = lambda x: mp.diff(f, x)

    start, end = mp.mpf(domain[0]), mp.mpf(domain[1])
    coeffs = [c[i] for i in range(len(c))]

    # Derivative of the polynomial: j * c_j for j >= 1, reversed because
    # mp.polyval expects the highest degree first.
    dcoeffs = [j * coeffs[j] for j in range(1, len(coeffs))][::-1]

    def error_slope(x):
        # d/dx (p(x) - f(x)); its roots are the error extrema.
        p_slope = mp.polyval(dcoeffs, x) if dcoeffs else mp.zero
        return p_slope - derivative(x)

    # Find extrema in the domain: bracket sign changes on a grid, then refine.
    grid = [start + (end - start) * k / samples for k in range(samples + 1)]
    slopes = [error_slope(x) for x in grid]
    found = [start, end]
    for left, right, s_left, s_right in zip(grid, grid[1:], slopes, slopes[1:]):
        if s_left == 0:
            found.append(left)
        elif s_left * s_right < 0:
            found.append(mp.findroot(error_slope, (left, right), solver='anderson', verify=False))
    return sorted(set(found))
    

In [None]:
def remez(x, n):
    """Remez linear system matrix: row i is [1, x_i, ..., x_i**(n-1), (-1)**i].

    The last column carries the alternating sign of the levelled error E, so
    solving A [c_0 .. c_{n-1}, E]^T = f(x) needs len(x) == n + 1 nodes.
    """
    A = mp.matrix(len(x), n + 1)
    for row, node in enumerate(x):
        power = mp.one
        for col in range(n):
            A[row, col] = power
            power *= node
        A[row, n] = -mp.one if row % 2 else mp.one
    return A

In [None]:
# 17 polynomial coefficients plus the levelled error E make 18 unknowns, so the
# Remez system needs 18 nodes. Reusing the 17 nodes above left it
# underdetermined, and the pseudo-inverse returned a minimum-norm solution.
x_remez = chebyshev_nodes(d, 18)
y_remez = mp.matrix([f(node) for node in x_remez])
c = least_squares(remez(x_remez, 17), y_remez)
print(c)
[hex(mp_to_fix(coeff)) for coeff in c]

## LibFixedMath

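**TODO:** LibFixedMath operates on signed `int256` values, whereas this simulation stores every word as an unsigned 256-bit integer in two's complement.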

In [None]:
FIXED_1 = 0x0000000000000000000000000000000080000000000000000000000000000000  # 1.0 (2**127)
# Low 127 bits of a word: its fractional part.
MANTISSA_MASK = 0x7fffffffffffffffffffffffffffffff
MANTISA_MASK = MANTISSA_MASK  # alias for the original misspelled name
LN_MIN_VAL = 0x0000000000000000000000000000000000000000000000000000000733048c5a
LN_MAX_VAL = FIXED_1
EXP_MIN_VAL = BASE - 0x0000000000000000000000000000001ff0000000000000000000000000000000

In [None]:
def lib_toInteger(x):
    """Integer part of a signed fixed-point word, truncated toward zero.

    LibFixedMath works on signed values, so this uses signed division. Unsigned
    `div` turned negative words into huge positive integers.
    """
    return sdiv(x, FIXED_1)

In [None]:
def lib_toMantissa(x):
    """Fractional part of a fixed-point word: its low 127 bits.

    For a negative (two's-complement) word this is the floor remainder,
    x - floor(x), so it is always in [0, FIXED_1). A zero input returns 0.
    """
    _valid(x)
    # FIXED_1 - 1 is the 127-bit mask. Writing it this way does not depend on
    # how the mask constant is spelled.
    return x & (FIXED_1 - 1)

In [None]:
def lib_mul(a, b):
    """Fixed-point product floor(a * b / 2**127) of two signed words, mod 2**256.

    Each operand is split as x = I * FIXED_1 + M, with 0 <= M < FIXED_1, so that
        a * b / FIXED_1 = FIXED_1*aI*bI + aI*bM + aM*bI + aM*bM / FIXED_1
    and every term fits in 256 bits. A result that itself overflows still
    wraps, like `mul`.
    """
    # Floor decomposition: M is the low 127 bits (also for negative words).
    # a - M is then an exact multiple of FIXED_1, so the signed division is exact.
    aM = mod(a, FIXED_1)
    bM = mod(b, FIXED_1)
    aI = sdiv(sub(a, aM), FIXED_1)
    bI = sdiv(sub(b, bM), FIXED_1)

    integerPart = mul(FIXED_1, mul(aI, bI))
    crossPart = add(mul(aI, bM), mul(aM, bI))
    # aM * bM < 2**254 is non-negative, so unsigned division floors it.
    fractionPart = div(mul(aM, bM), FIXED_1)
    return add(add(integerPart, crossPart), fractionPart)

In [None]:
def lib_log1(x):
    """Simulated LibFixedMath series for ln(x) with x close to 1.

    Valid over range ~0.87 - 1.0. With y = x - 1 and w = y*y (in fixed point),
    it sums eight terms z_k * (c_k - y) / d_k, where z starts at y and is
    multiplied by w between terms.
    """
    y = sub(x, FIXED_1)
    w = sdiv(mul(y, y), FIXED_1)
    # (c_k, d_k) pairs, in evaluation order.
    terms = [
        (0x100000000000000000000000000000000, 0x100000000000000000000000000000000),
        (0x0aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, 0x200000000000000000000000000000000),
        (0x099999999999999999999999999999999, 0x300000000000000000000000000000000),
        (0x092492492492492492492492492492492, 0x400000000000000000000000000000000),
        (0x08e38e38e38e38e38e38e38e38e38e38e, 0x500000000000000000000000000000000),
        (0x08ba2e8ba2e8ba2e8ba2e8ba2e8ba2e8b, 0x600000000000000000000000000000000),
        (0x089d89d89d89d89d89d89d89d89d89d89, 0x700000000000000000000000000000000),
        (0x088888888888888888888888888888888, 0x800000000000000000000000000000000),
    ]
    r = 0
    z = y
    for k, (numerator, denominator) in enumerate(terms):
        if k > 0:
            z = sdiv(mul(z, w), FIXED_1)
        r = add(r, sdiv(mul(z, sub(numerator, y)), denominator))
    return r

In [None]:
t = []
def lib_ln(x, reductions='old', log='old'):
    """Simulated LibFixedMath-style ln of a fixed-point word `x` in (0, 1].

    Repeatedly divides `x` by the reduction constants v_i ≈ e**-(32 / 2**i)
    while subtracting 32 / 2**i from the result. It then evaluates the
    near-1 series on what remains.

    Args:
        x: fixed-point word.
        reductions: 'old' selects the truncated constants; any other value
            selects the full-precision ones.
        log: 'new' finishes with `my_log1`; anything else with `lib_log1`.

    Returns:
        ln(x) as a fixed-point word. Returns EXP_MIN_VAL for x <= LN_MIN_VAL.

    Raises:
        ValueError: "ln: value too small" for x <= 0, and "ln: value too
            large" for x > 1.

    Side effect:
        Appends each reduced argument to the global list `t`, which later
        cells use to inspect the series' input range.
    """
    global t
    _valid(x)
    # Negative words are >= SIGN > LN_MAX_VAL, so they must be rejected before
    # the upper-bound check, which would otherwise call them "too large".
    if x == 0 or x >= SIGN:
        raise ValueError("ln: value too small")
    if x > LN_MAX_VAL:
        raise ValueError("ln: value too large")
    if x == FIXED_1:
        return 0
    if x <= LN_MIN_VAL:
        return EXP_MIN_VAL

    old_values = [
        0x00000000000000000000000000000000000000000001c8464f76164760000000,
        0x00000000000000000000000000000000000000f1aaddd7742e90000000000000,
        0x00000000000000000000000000000000000afe10820813d78000000000000000,
        0x0000000000000000000000000000000002582ab704279ec00000000000000000,
        0x000000000000000000000000000000001152aaa3bf81cc000000000000000000,
        0x000000000000000000000000000000002f16ac6c59de70000000000000000000,
        0x000000000000000000000000000000004da2cbf1be5828000000000000000000,
        0x0000000000000000000000000000000063afbe7ab2082c000000000000000000,
        0x0000000000000000000000000000000070f5a893b608861e1f58934f97aea57d,
    ]
    new_values = [
        0x1c8464f76164681e299a0,
        0xf1aaddd7742e56d32fb9f99744,
        0xafe10820813d65dfe6a33c07f738f,
        0x2582ab704279e8efd15e0265855c47b,
        0x1152aaa3bf81cb9fdb76eae12d029572,
        0x2f16ac6c59de6f8d5d6f63c1482a7c87,
        0x4da2cbf1be5827f9eb3ad1aa9866ebb4,
        0x63afbe7ab2082ba1a0ae5e4eb1b479dd,
        0x70f5a893b608861e1f58934f97aea57d,
    ]
    reduction_values = old_values if reductions == 'old' else new_values

    r = 0
    for i, v in enumerate(reduction_values):
        if x <= v:
            # 0x...10 << 128 is 32.0 in fixed point; step i subtracts 32 / 2**i.
            r = sub(r, 0x0000000000000000000000000000001000000000000000000000000000000000 >> i)
            x = sdiv(mul(x, FIXED_1), v)

    t += [x]
    if log == 'new':
        return add(r, my_log1(x))
    # lib_log1 is the same eight-term series this function used to inline.
    return add(r, lib_log1(x))

In [None]:
# Print 2**i and exp(-2**i) as a fixed-point hex word, for i = -3 .. 5: the
# argument-reduction constants used by lib_ln (listed there largest exponent first).
for exponent in range(-3, 6):
    print(2**exponent)
    print(hex(mp_to_fix(mp.exp(-mp.mpf(2) ** exponent))))

In [None]:
def my_mul(a, b):
    """Signed fixed-point product of two words: bits 127..382 of the 512-bit a * b."""
    lo = mul(a, b)                       # low 256 bits of the unsigned product
    mm = mulmod(a, b, bnot(0))           # unsigned product mod 2**256 - 1
    hi = sub(sub(mm, lo), lt(mm, lo))    # high 256 bits of the unsigned product
    # Signed correction of the high word: sar(256, v) is all ones (-1) for
    # negative v and 0 otherwise, which subtracts b when a < 0 and a when b < 0.
    hi = add(hi, mul(sar(256, a), b))
    hi = add(hi, mul(sar(256, b), a))
    # (hi * 2**256 + lo) >> 127, truncated to 256 bits.
    return bor(shl(129, hi), shr(127, lo))

In [None]:
# Test operands for my_mul: `a` is a positive fixed-point word, `b` the
# two's-complement word of a negative value.
a, b = 2617556668622594272776707985386330127, BASE - 130444458470968929913751441309200329514

In [None]:
# Operand a as a float.
fix_to_f(a)

In [None]:
# Operand b as a float (negative).
fix_to_f(b)

In [None]:
# Simulated signed fixed-point product of the test operands, shown as the raw word.
d = my_mul(a, b)
d

In [None]:
# my_mul result as a float, to compare with the float product of the operands.
fix_to_f(d)

In [None]:
# Float reference for my_mul(a, b). It is computed from the operands rather
# than from hard-coded copies of their printed values, so it stays in sync.
fix_to_f(a) * fix_to_f(b)

In [None]:
# Sanity check: arithmetic shift of a negative word by 256 bits; expected all ones (-1).
hex(sar(256, mp_to_fix(-2)))

## Evaluate

In [None]:
# 10000 fixed-point sample words on [0.875, 1.0] and their float values.
x = intrange(mp_to_fix(0.875), mp_to_fix(1.0), num=10000)
xr = list(map(fix_to_f, x))

In [None]:
def fix_exp(x):
    """Ground-truth x ** (1 / (1 - x)) for a fixed-point word `x`, as a word.

    Python ints have no `.pow`, so the value is computed in mpmath. At x == 1
    the exponent is singular; the expression's limit there is e**-1, which is
    returned instead of dividing by zero. NOTE(review): assumes x > 0, since a
    negative base with a non-integer exponent would give a complex result.
    """
    xm = fix_to_mp(x)
    if xm == 1:
        return mp_to_fix(mp.exp(-1))
    return mp_to_fix(xm ** (1 / (1 - xm)))

In [None]:
# Evaluate fix_exp on the grid: raw words and their float values.
y = np.array(list(map(fix_exp, x)))
yr = np.array(list(map(fix_to_f, y)))

### Log(1 + z)

In [None]:
# Sample grid for the ln-near-1 comparison: [0.875, 1.0] in fixed point.
x = intrange(mp_to_fix(0.875), mp_to_fix(1.0), num=10000)
xr = list(map(fix_to_f, x))

In [None]:
# Ground truth ln on the grid.
y = np.array([mp_log1(sample) for sample in x])
yr = np.array([fix_to_f(word) for word in y])

In [None]:
# LibFixedMath series on the grid.
ly = np.array([lib_log1(sample) for sample in x])
lyr = np.array([fix_to_f(word) for word in ly])

In [None]:
# Chebyshev approximation on the grid.
my = np.array([my_log1(sample) for sample in x])
myr = np.array([fix_to_f(word) for word in my])

In [None]:
# Values of the three ln implementations near 1.
plt.plot(xr, yr, label='mpmath (ground truth)')
plt.plot(xr, lyr, label='lib_log1')
plt.plot(xr, myr, label='my_log1')
plt.title('ln(x) near 1')
plt.xlabel('$x$')
plt.ylabel('$\\ln x$')
plt.legend();

In [None]:
# Error in bits. Differences are taken on signed words (sub, then _abs), so
# results on opposite sides of zero are not reported as ~2**256 apart.
with np.errstate(divide='ignore'):  # exact matches give log2(0) = -inf
    lib_err_bits = np.log2([float(_abs(sub(int(approx), int(exact)))) for approx, exact in zip(ly, y)])
    my_err_bits = np.log2([float(_abs(sub(int(approx), int(exact)))) for approx, exact in zip(my, y)])
plt.plot(xr, lib_err_bits, label='lib_log1')
plt.plot(xr, my_err_bits, label='my_log1')
plt.title('ln(x) near 1: absolute error')
plt.xlabel('$x$')
plt.ylabel('$\\log_2 \\vert \\mathrm{error} \\vert$ (fixed-point units)')
plt.legend();

### Exp(x)

In [None]:
# Sample grid for exp: [0, 0.125] in fixed point.
x = intrange(mp_to_fix(0.0), mp_to_fix(0.125), num=10000)
xr = list(map(fix_to_f, x))

In [None]:
# Ground truth exp on the grid.
y = np.array([mp_exp(sample) for sample in x])
yr = np.array([fix_to_f(word) for word in y])

In [None]:
# Chebyshev exp approximation on the grid.
my = np.array([my_exp(sample) for sample in x])
myr = np.array([fix_to_f(word) for word in my])

In [None]:
# Values of exp: ground truth vs approximation.
plt.plot(xr, yr, label='mpmath (ground truth)')
plt.plot(xr, myr, label='my_exp')
plt.title('exp(x) on [0, 0.125]')
plt.xlabel('$x$')
plt.ylabel('$e^x$')
plt.legend();

In [None]:
# Error in bits, measured as a signed word difference (sub, then _abs).
with np.errstate(divide='ignore'):  # exact matches give log2(0) = -inf
    plt.plot(xr, np.log2([float(_abs(sub(int(approx), int(exact)))) for approx, exact in zip(my, y)]))
plt.title('exp(x): absolute error of my_exp')
plt.xlabel('$x$')
plt.ylabel('$\\log_2 \\vert \\mathrm{error} \\vert$ (fixed-point units)');

### Ln(x)

In [None]:
# Reset the record of reduced arguments that lib_ln appends to on every call.
t = []

In [None]:
# Sample grid for ln over its full input range [LN_MIN_VAL, 1.0].
x = intrange(LN_MIN_VAL, LN_MAX_VAL, num=10000)
xr = list(map(fix_to_f, x))

In [None]:
# Ground truth ln on the grid.
y = np.array([mp_ln(sample) for sample in x])
yr = np.array([fix_to_f(word) for word in y])

In [None]:
# lib_ln with the original reduction constants and series.
ly = np.array([lib_ln(sample) for sample in x])
lyr = np.array([fix_to_f(word) for word in ly])

In [None]:
# lib_ln with full-precision reduction constants and the Chebyshev final step.
lz = np.array([lib_ln(sample, reductions='new', log='new') for sample in x])
lzr = np.array([fix_to_f(word) for word in lz])

In [None]:
# Values of ln over the full input range.
plt.plot(xr, yr, label='mpmath (ground truth)')
plt.plot(xr, lyr, label="lib_ln (old)")
plt.plot(xr, lzr, label="lib_ln (new)")
plt.title('ln(x) on [LN_MIN_VAL, 1]')
plt.xlabel('$x$')
plt.ylabel('$\\ln x$')
plt.legend();

In [None]:
# Error in bits. Differences are taken on signed words (sub, then _abs), so a
# result and a ground truth on opposite sides of zero near x = 1 are not
# reported as ~2**256 apart.
with np.errstate(divide='ignore'):  # exact matches give log2(0) = -inf
    old_err_bits = np.log2([float(_abs(sub(int(approx), int(exact)))) for approx, exact in zip(ly, y)])
    new_err_bits = np.log2([float(_abs(sub(int(approx), int(exact)))) for approx, exact in zip(lz, y)])
plt.plot(xr, old_err_bits, label='lib_ln (old)')
plt.plot(xr, new_err_bits, label='lib_ln (new)')
plt.title('ln(x): absolute error')
plt.xlabel('$x$')
plt.ylabel('$\\log_2 \\vert \\mathrm{error} \\vert$ (fixed-point units)')
plt.legend();

In [None]:
# Range of the reduced arguments lib_ln passed to its final series, accumulated
# across both lib_ln runs above (t was reset before them).
(fix_to_f(min(t)), fix_to_f(max(t)))

In [None]:
# Reduced arguments in call order (as floats).
plt.plot(list(map(fix_to_f, t)))