# Magic and Imports

In [None]:
%matplotlib widget

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Generate Random Ternary Matrix

In [None]:
def generate_random_matrix(var, n, m=None):
    """Return a random ternary matrix with entries in {-1, 0, +1}.

    Each entry is non-zero with probability ``var`` and, when non-zero, is
    +1 or -1 with equal probability. Entries therefore have mean 0 and
    variance ``var``.

    Args:
        var: Probability of an entry being non-zero; must lie in [0, 1].
        n: Number of rows.
        m: Number of columns. Defaults to ``n`` (square matrix).

    Returns:
        An ``(n, m)`` integer array.

    Raises:
        ValueError: If ``var`` is outside [0, 1] (this also rejects NaN).
    """
    # Written as a chained comparison so that NaN fails the check as well.
    if not 0 <= var <= 1:
        raise ValueError(f"var must be in [0, 1], got {var}")
    if m is None:
        m = n

    # First draw decides which entries are non-zero ...
    X = (np.random.rand(n, m) < var).astype(int)
    # ... second, independent draw decides which entries get their sign flipped.
    negate = np.random.rand(n, m) > 0.5
    X[negate] = -X[negate]

    return X

In [None]:
# Small 10x10 example; n is the only size given, so the matrix is square.
X = generate_random_matrix(var=1, n=10)

In [None]:
# Display the sample mean and sample variance of X as a tuple.
X.mean(), X.var()

# Investigating Product of Independent Random Matrices

In [None]:
n = 128
# low=512, high=513 pins every drawn width to 512 (the upper bound is exclusive).
ns = np.random.randint(low=512, high=513, size=10)

# Head of the chain: three square n x n matrices with var = 1/2.
Xs = []
for _ in range(3):
    Xs.append(generate_random_matrix(var=1/2, n=n))

# Tail of the chain: each matrix has as many rows as its predecessor has
# columns, and its var is the reciprocal of its row count.
last_n = n
for n_i in ns:
    Xs.append(generate_random_matrix(var=1/last_n, n=last_n, m=n_i))
    last_n = n_i


In [None]:
# One summary line per matrix: index, shape, absolute mean (2 decimals)
# and standard deviation (2 significant digits).
for i, X in enumerate(Xs):
    print(f"{i} : dim {X.shape} m {np.abs(X.mean()):.2f} s {X.std():.2g}")

In [None]:
# Running product of the chain, seeded with the integer identity; every
# intermediate product is kept.
Ys = []
Y = np.eye(n, dtype=int)
for X in Xs:
    Y = Y @ X
    Ys.append(Y)

In [None]:
# One summary line per product: index, absolute mean (2 decimals) and
# standard deviation (no decimals).
for i, Y in enumerate(Ys):
    print(f"{i} : m {np.abs(Y.mean()):.2f} s {Y.std():.0f}")

In [None]:
# Display how many entries Ys holds.
len(Ys)

In [None]:
FIG_NAME = "distribution"
plt.close(FIG_NAME)  # discard any figure left over from an earlier run of this cell

# Smallest square grid with room for one histogram per entry of Ys.
num_ys = len(Ys)
ceil_sqrt_num_ys = np.ceil(np.sqrt(num_ys)).astype(int)

fig, axs = plt.subplots(ceil_sqrt_num_ys, ceil_sqrt_num_ys, num=FIG_NAME)
# subplots may return a single Axes or a 2-D array; normalise to a flat array.
axs = np.array(axs).flatten()

# zip stops at the shorter input, so only the first num_ys axes get a histogram.
for i, (ax, Y) in enumerate(zip(axs, Ys)):
    ax.set_title(i)

    mean = Y.mean()
    std = Y.std()

    # Limit the histogram to +/- 3 standard deviations around the mean.
    ax.hist(Y.flatten(), range=(mean-3*std, mean+3*std), density=True)

    # Solid line at the mean, dashed lines at +/- 2 standard deviations.
    ax.axvline(mean, color="red", ls="-")
    ax.axvline(mean+2*std, color="red", ls="--")
    ax.axvline(mean-2*std, color="red", ls="--")

# Cross out the grid cells that have no matrix to show. Slicing the leftover
# axes replaces padding Ys with a fixed number of None placeholders.
for ax in axs[num_ys:]:
    ax.plot(np.linspace(-1, 1), np.linspace(-1, 1), '-', color="red")
    ax.plot(np.linspace(1, -1), np.linspace(-1, 1), '-', color="red")
    ax.axis("off")

fig.tight_layout(pad=0.1)

plt.show()

# Fraction of Empty Rows

In [None]:
# N = 1, 2, 4, ..., 4096
x = 2 ** np.arange(13)
# ((N - 1) / N) ** N, evaluated element-wise
y = np.power((x - 1) / x, x)

In [None]:
FIG_NAME = "1/e"
plt.close(FIG_NAME)  # start from a clean figure when the cell is re-run
fig, ax = plt.subplots(nrows=1, ncols=1, num=FIG_NAME)

# Sampled values, with a horizontal reference line at 1/e.
ax.plot(x, y, 'o--')
ax.axhline(1/np.e, linestyle='-', color="red")

ax.set_xscale("log")
ax.grid()
ax.set_ylim(0, 1)

ax.set_xlabel("N")
ax.set_ylabel("prob row being all 0")

plt.show()

In [None]:
# Fraction of all-zero rows, pooled over the last len(ns) matrices of Xs.
# The list is created once, before the loop: re-creating it per matrix would
# discard every matrix but the last from the reported fraction.
ts = []
for X in Xs[-len(ns):]:
    # One boolean per row: True when every entry of that row is 0.
    ts.extend((X == 0).all(axis=1))

print(np.mean(ts))
print(1/np.e)


# Bit Shifting

In [None]:
# Python's >> on ints rounds toward negative infinity, so -1 >> 1 stays -1
# while 1 >> 1 and 0 >> 1 are both 0.
print(f"""
 -1 >> 1 = {(-1) >> 1}
  1 >> 1 = {(1) >> 1}
  0 >> 1 = {(0) >> 1}
""")

In [None]:
# Integers -200 .. 199 and each of them shifted right by one bit.
ints = np.arange(-200, 200, dtype=int)
shifted = ints >> 1


In [None]:
FIG_NAME = "bitshift"
plt.close(FIG_NAME)  # start from a clean figure when the cell is re-run

fig, axs = plt.subplots(nrows=1, ncols=2, num=FIG_NAME)
full_ax, zoom_ax = axs

# Left panel: every point.
full_ax.plot(ints, shifted)

# Right panel: the `num` points centred on the middle of the array.
num = 20
half = len(ints) // 2
my_range = range(half - num//2, half + num//2)
zoom_ax.plot(ints[my_range], shifted[my_range])

for ax in (full_ax, zoom_ax):
    ax.grid()
    ax.set_xlabel("x")
    ax.set_ylabel("x >> 1")

fig.tight_layout()
plt.show()

# Normalisation by Bit-Shifting

In [None]:
# Number of bits used for the right shift.
BIT_SHIFT = 1
# Original (misspelled) name, kept as an alias because other cells still read it.
BIT_SHIT = BIT_SHIFT
# Square of the shift's scale factor: 2 ** (2 * shift) == (2 ** shift) ** 2.
BIT_SHIFT_VAR = 2 ** (BIT_SHIFT*2)

BIT_SHIFT_VAR, BIT_SHIT

In [None]:
n = 128
# low=512, high=513 pins every drawn width to 512 (the upper bound is exclusive).
ns = np.random.randint(low=512, high=513, size=10)

# Head of the chain: three square n x n matrices with var = 1.
Xs = []
for _ in range(3):
    Xs.append(generate_random_matrix(var=1, n=n))

# Tail of the chain: each matrix has as many rows as its predecessor has
# columns, and its var is BIT_SHIFT_VAR divided by its row count.
last_n = n
for n_i in ns:
    Xs.append(generate_random_matrix(var=BIT_SHIFT_VAR/last_n, n=last_n, m=n_i))
    last_n = n_i


In [None]:
# One summary line per matrix: index, shape, absolute mean (2 decimals)
# and standard deviation (2 significant digits).
for i, X in enumerate(Xs):
    print(f"{i} : dim {X.shape} m {np.abs(X.mean()):.2f} s {X.std():.2g}")

In [None]:
# Running product of the chain, seeded with the integer identity. After each
# multiplication the result is shifted right by BIT_SHIT bits before it is
# stored and carried into the next step.
Ys = []
Y = np.eye(n, dtype=int)
for X in Xs:
    Y = np.right_shift(Y @ X, BIT_SHIT)
    Ys.append(Y)

In [None]:
# One summary line per product: index, absolute mean (2 decimals) and
# standard deviation (no decimals).
for i, Y in enumerate(Ys):
    print(f"{i} : m {np.abs(Y.mean()):.2f} s {Y.std():.0f}")

In [None]:
# Display how many entries Ys holds.
len(Ys)

In [None]:
# Distinct from the "distribution" figure of the un-shifted section, so that
# closing a stale copy of this figure does not close that one.
FIG_NAME = "distribution (bit-shifted)"
plt.close(FIG_NAME)  # discard any figure left over from an earlier run of this cell

# Smallest square grid with room for one histogram per entry of Ys.
num_ys = len(Ys)
ceil_sqrt_num_ys = np.ceil(np.sqrt(num_ys)).astype(int)

fig, axs = plt.subplots(ceil_sqrt_num_ys, ceil_sqrt_num_ys, num=FIG_NAME)
# subplots may return a single Axes or a 2-D array; normalise to a flat array.
axs = np.array(axs).flatten()

# zip stops at the shorter input, so only the first num_ys axes get a histogram.
for i, (ax, Y) in enumerate(zip(axs, Ys)):
    ax.set_title(i)

    mean = Y.mean()
    std = Y.std()

    # Limit the histogram to +/- 3 standard deviations around the mean.
    ax.hist(Y.flatten(), range=(mean-3*std, mean+3*std), density=True)

    # Solid line at the mean, dashed lines at +/- 2 standard deviations.
    ax.axvline(mean, color="red", ls="-")
    ax.axvline(mean+2*std, color="red", ls="--")
    ax.axvline(mean-2*std, color="red", ls="--")

# Cross out the grid cells that have no matrix to show. Slicing the leftover
# axes replaces padding Ys with a fixed number of None placeholders.
for ax in axs[num_ys:]:
    ax.plot(np.linspace(-1, 1), np.linspace(-1, 1), '-', color="red")
    ax.plot(np.linspace(1, -1), np.linspace(-1, 1), '-', color="red")
    ax.axis("off")

fig.tight_layout(pad=0.1)

plt.show()

# Fraction of Empty Rows - Again

In [None]:
# Powers of two from 2**1 to 2**13; only those above 4 are kept, which
# leaves (N - 4) / N strictly positive.
x = 2 ** np.arange(1, 14)
x = x[x > 4]

# ((N - 4) / N) ** N, evaluated element-wise
y = np.power((x - 4) / x, x)

In [None]:
FIG_NAME = "e^-4"
plt.close(FIG_NAME)  # start from a clean figure when the cell is re-run
fig, ax = plt.subplots(nrows=1, ncols=1, num=FIG_NAME)

# Sampled values, with a horizontal reference line at e**-4.
ax.plot(x, y, 'o--')
ax.axhline(np.exp(-4), linestyle='-', color="red")

ax.set_xscale("log")
ax.grid()
ax.set_ylim(0, 1)

ax.set_xlabel("N")
ax.set_ylabel("prob row being all 0")

plt.show()

In [None]:
# Display the numeric value of e**-4.
np.exp(-4)

In [None]:
# Shift amounts 0 .. 4 and, for each, 2 ** (2 * shift) == 4 ** shift.
bitshifts = np.arange(5)
bitshift_vars = 4 ** bitshifts

# exp(-v) for each of those values.
probs = np.exp(-bitshift_vars)

In [None]:
FIG_NAME = "bitshift zero rows"
plt.close(FIG_NAME)  # start from a clean figure when the cell is re-run
fig, ax = plt.subplots(nrows=1, ncols=1, num=FIG_NAME)

ax.plot(bitshifts, probs)

ax.set_xlabel("bitshift n")
ax.set_ylabel("probability row being all 0")

ax.grid()
# One tick per plotted shift amount.
ax.set_xticks(bitshifts)

plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# n = 2, 4, ..., 2**20
n = 2 ** (np.arange(20) + 1)
# ceil(2 * log2(n)). np.log2 is used instead of log(n) / log(2) because a
# rounding error in that quotient could land just above an integer, which
# ceil would then turn into an off-by-one.
bits_int = np.ceil(2 * np.log2(n))
sparsity_for_equality = (2 / (1 + bits_int))

In [None]:
FIG_NAME = "sparsity"
plt.close(FIG_NAME)  # start from a clean figure when the cell is re-run

fig, axs = plt.subplots(nrows=1, ncols=2, num=FIG_NAME)

# Left panel: the break-even density itself, with y ticks every 0.05.
axs[0].plot(n, sparsity_for_equality, 'x')
axs[0].set_ylabel("density for equality")
axs[0].set_yticks(np.arange(21)/20)

# Right panel: break-even density times n, divided by 1, 2, 10 and 100.
scaled_curves = (
    (1, "1.0 memory"),
    (2, "0.5 memory"),
    (10, "0.1 memory"),
    (100, "0.01 memory"),
)
for divisor, curve_label in scaled_curves:
    axs[1].plot(n, sparsity_for_equality * n / divisor, 'x', label=curve_label)
axs[1].set_ylabel("density for equality * n")
axs[1].set_yscale("log")
axs[1].legend()

for ax in axs:
    ax.set_xscale("log")
    ax.set_xlabel("n")
    ax.grid()

plt.show()

In [None]:
# 50 evenly spaced samples per axis (linspace's default count): 0..7 along
# x and 0..5 along y, expanded into 2-D coordinate grids.
x_axis = np.linspace(0, 7)
y_axis = np.linspace(0, 5)
xx, yy = np.meshgrid(x_axis, y_axis)

In [None]:
# Dividing by log10(e) converts a base-10 logarithm into a natural logarithm
# (it is the same as multiplying by ln 10).
# NOTE(review): presumably xx is log10(n) and yy is log10(density * n), as the
# axis labels of the plot of this grid suggest — confirm.
xx_e = xx / np.log10(np.e)
yy_e = yy / np.log10(np.e)

# 2 * xx_e / ln(2) is twice the base-2 logarithm of exp(xx_e), rounded up.
bits_per_int = np.ceil(2*(xx_e)/(np.log(2)))
# Natural log of the ratio, assembled term by term:
# yy_e - ln(2 / (1 + bits_per_int)) - xx_e.
ln_memory_save = yy_e - np.log(2 / (1+bits_per_int)) - xx_e
# Convert the result back to base 10.
log10_memory_save = ln_memory_save / np.log(10)

In [None]:
# Display the smallest value of ln_memory_save over the whole grid.
ln_memory_save.min()

In [None]:
from matplotlib.colors import TwoSlopeNorm

In [None]:
FIG_NAME = "image"
plt.close(FIG_NAME)  # start from a clean figure when the cell is re-run

# Keep the Axes returned by subplots rather than looking it up again later.
fig, ax = plt.subplots(1, 1, num=FIG_NAME)

# Diverging colour map centred on 0, i.e. on a ratio of 10**0 = 1.
plt.imshow(
    log10_memory_save, 
    cmap="seismic",
    norm=TwoSlopeNorm(vcenter=0),
    extent=(xx.min(), xx.max(), yy.max(), yy.min()),
    aspect="auto",
)

# The extent lists the y limits in reverse order; inverting the axis puts the
# smallest y value at the bottom.
ax.invert_yaxis()

# One tick per whole power of ten on both axes.
xticks = np.arange(xx.min(), xx.max()+1)
yticks = np.arange(yy.min(), yy.max()+1)

# The exponent is wrapped in mathtext braces (as on the colour bar below) so
# that multi-digit or negative exponents render as a single superscript.
ax.set_xticks(xticks)
ax.set_xticklabels([f"$10^{{{int(p)}}}$" for p in xticks])
ax.set_yticks(yticks)
ax.set_yticklabels([f"$10^{{{int(p)}}}$" for p in yticks])

ax.set_xlabel("n")
ax.set_ylabel("density * n")

# Colour-bar ticks at whole powers of ten within the data range.
cticks = np.arange(np.floor(log10_memory_save.min()), np.ceil(log10_memory_save.max()), 1)
cbar = plt.colorbar(ticks=cticks)
cbar.ax.set_yticklabels(labels=[f"$10^{{{int(p)}}}$" for p in cticks])
cbar.ax.set_ylabel("Ratio of memory")

ax.grid()

plt.title("Ratio of sparse to non-sparse memory")

plt.show()

In [None]:

# n = 2, 4, ..., 2**20
n = 2 ** (np.arange(20) + 1)
# ceil(2 * log2(n)). np.log2 is used instead of log(n) / log(2) because a
# rounding error in that quotient could land just above an integer, which
# ceil would then turn into an off-by-one.
bits_int = np.ceil(2 * np.log2(n))
sparsity_for_equality = (2 / (1 + bits_int))