# Mathematics of Arrays in Egglog


This notebook shows that if we define array operations as higher-order functions, we can compose them and end up with a simpler algebra that uses only booleans, integers, and functions.

We take as our input this MoA program, defined in [the PSI compiler](https://saulshanabrook.github.io/psi-compiler/src/):


```
main ()

{
  array Amts^3 <2 3 4>;
  array Ams^3 <2 3 4>;
  const array RAMY^3 <2 3 4>=<1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 
				11 12>;
  const array AMY^3 <2 3 4>=<9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9>;
  Amts=<2> take (<2> drop (RAMY cat AMY));
}
```

This result `Amts` is equivalent to `AMY`. We concatenate `RAMY` and `AMY` along the first axis, which gives a first axis of length 4. We then drop the first 2 entries along that axis, which removes all of `RAMY`. Finally we take the next 2 entries, which are all of `AMY`.

Compiling it produces this C program, which copies `AMY` into `Amts`:

```c
#include <stdlib.h>
#include "moalib.e"

main()

{
  double *offset0;
  int i0;
  int i1;
  int i2;
  double *shift;
  double _RAMY[]={1.000000, 2.000000, 3.000000, 4.000000, 5.000000,
    6.000000, 7.000000, 8.000000, 9.000000, 10.000000,
    11.000000, 12.000000, 1.000000, 2.000000, 3.000000,
    4.000000, 5.000000, 6.000000, 7.000000, 8.000000,
    9.000000, 10.000000, 11.000000, 12.000000};
  double _AMY[]={9.000000, 9.000000, 9.000000, 9.000000, 9.000000,
    9.000000, 9.000000, 9.000000, 9.000000, 9.000000,
    9.000000, 9.000000, 9.000000, 9.000000, 9.000000,
    9.000000, 9.000000, 9.000000, 9.000000, 9.000000,
    9.000000, 9.000000, 9.000000, 9.000000};
  double _Y[]={8.000000, 8.000000, 8.000000, 8.000000, 8.000000,
    8.000000, 8.000000, 8.000000, 8.000000, 8.000000,
    8.000000, 8.000000, 8.000000, 8.000000, 8.000000,
    8.000000, 8.000000, 8.000000, 8.000000, 8.000000,
    8.000000, 8.000000, 8.000000, 8.000000};
  double _V[]={1.000000, 1.000000};
  double _Amts[2*3*4];

/*******
Amts=<2.000000> take (<2.000000> drop (RAMY cat AMY))
********/

  shift=_Amts+0*12+0*4+0;
  offset0=_AMY+0*12+0*4+0;
  for (i0=0; i0<2; i0++) {
    for (i1=0; i1<3; i1++) {
      for (i2=0; i2<4; i2++) {
        *(shift)= *(offset0);
        offset0+=1;
        shift+=1;
      }
    }
  }
```

What we want to show here is not the full compilation into C and into loops. It is only that, by defining each array operation as a higher-order function, we can compose them and end up with a simpler algebra that uses only booleans, integers, and functions. That algebra could then be compiled into loops. The hypothesis here is that we don't *lose* any information by erasing the `take`, `drop`, and `cat` operations and replacing them with their definitions in terms of functions.


In [27]:
from __future__ import annotations

from collections.abc import Callable

from egglog import *

# One ruleset shared by every definition below. The classes and `@function`s
# attach their rules to it, and so do the `@array_ruleset.register` generators.
# A single `array_ruleset.saturate()` run therefore applies the integer, tuple
# and array rules together.
array_ruleset = ruleset(name="array_ruleset")


# Symbolic boolean expression. `Boolean(True)` / `Boolean(False)` are the
# literals that the comparison rules in `_int` rewrite `Int` comparisons into.
class Boolean(Expr):
    def __init__(self, val: BoolLike) -> None: ...

    # Conditional between two `Int`s. It is rewritten to `then` when the receiver
    # is Boolean(True) and to `else_` when it is Boolean(False) (see `_int`).
    # While the condition is still symbolic, the expression stays unevaluated.
    def if_bool(self, then: Int, else_: Int) -> Int: ...


# Symbolic integer expression. Operations on two literal `Int(i)` values are
# folded by the `_int` rules. Operations involving opaque values (e.g. the
# `constant("i", Int)` indices used later) remain as unevaluated terms.
class Int(Expr):
    def __init__(self, val: i64Like) -> None: ...
    # `==` builds a symbolic `Boolean` instead of returning a Python `bool`,
    # which is why the override is type-ignored.
    def __eq__(self, other: Int) -> Boolean: ...  # type: ignore[override]
    def __lt__(self, other: Int) -> Boolean: ...
    def __add__(self, other: Int) -> Int: ...
    def __sub__(self, other: Int) -> Int: ...
    def __mul__(self, other: Int) -> Int: ...


@array_ruleset.register
def _int(i: i64, j: i64, x: Int, y: Int):
    """
    Constant-folding rules for `Int` and `Boolean`.

    `i` and `j` match literal i64 payloads inside `Int(...)`. `x` and `y` match
    arbitrary `Int` expressions (the branches of `if_bool`). Arithmetic and
    comparisons between two literals are evaluated. `if_bool` on a literal
    Boolean is replaced by the selected branch.
    """
    # Arithmetic on two literals.
    yield rewrite(Int(i) + Int(j)).to(Int(i + j))
    yield rewrite(Int(i) - Int(j)).to(Int(i - j))
    yield rewrite(Int(i) * Int(j)).to(Int(i * j))
    # Equality: the same literal on both sides is True; the `i != j` guard
    # makes the second rule fire only for distinct literals.
    yield rewrite(Int(i) == Int(i)).to(Boolean(True))
    yield rewrite(Int(i) == Int(j)).to(Boolean(False), i != j)
    # Ordering: the two guards are complementary, so exactly one rule applies.
    yield rewrite(Int(i) < Int(j)).to(Boolean(True), i < j)
    yield rewrite(Int(i) < Int(j)).to(Boolean(False), i >= j)
    # Pick a branch once the condition has folded to a literal Boolean.
    yield rewrite(Boolean(True).if_bool(x, y)).to(x)
    yield rewrite(Boolean(False).if_bool(x, y)).to(y)


# Element `index` of a concrete `Vec`. It stays a symbolic term until `index`
# is a literal `Int`; then `_vec_index` indexes the Vec directly.
@function
def vec_index(vec: Vec[Int], index: Int) -> Int: ...


@array_ruleset.register
def _vec_index(i: i64, xs: Vec[Int]):
    """Resolve `vec_index(xs, Int(i))` to `xs[i]` once the index is a literal i64."""
    yield rewrite(vec_index(xs, Int(i))).to(xs[i])


# A tuple of `Int`s, represented lazily by a length plus a function from
# position to element rather than by stored elements. The accessor rules
# live in `_tuple_int`.
class TupleInt(Expr, ruleset=array_ruleset):
    def __init__(self, length: Int, getitem_fn: Callable[[Int], Int]) -> None: ...
    def __getitem__(self, index: Int) -> Int: ...

    @property
    def length(self) -> Int: ...

    @classmethod
    def from_vec(cls, xs: Vec[Int]) -> TupleInt:
        # Wrap a concrete Vec. The length is the Vec's length, and each lookup
        # is deferred to `vec_index`, which resolves once the position is a
        # literal `Int`.
        return TupleInt(
            Int(xs.length()),
            lambda i: vec_index(xs, i),
        )


@array_ruleset.register
def _tuple_int(l: Int, fn: Callable[[Int], Int], i: Int):
    """
    Accessor rules for the `TupleInt` constructor.

    `.length` returns the stored length. `ti[i]` applies the stored function
    `fn` to `i`.
    """
    ti = TupleInt(l, fn)
    yield rewrite(ti.length).to(l)
    yield rewrite(ti[i]).to(fn(i))


# An n-dimensional array, represented by its shape plus a function from an
# index tuple to the element at that index. The accessor rules live in
# `_ndarray`.
class NDArray(Expr, ruleset=array_ruleset):
    def __init__(self, shape: TupleInt, idx_fn: Callable[[TupleInt], Int]) -> None: ...

    @classmethod
    def from_memory(cls, shape: TupleInt, values: TupleInt) -> NDArray:
        # View the flat tuple `values` as an array of `shape`, using row-major
        # (C-order) strides: flat = i0 * (s1 * s2) + i1 * s2 + i2.
        # Only works for ndim = 3 for now: the index expression reads exactly
        # idx[0..2] and shape[1..2].
        return NDArray(
            shape,
            lambda idx: values[
                idx[Int(0)] * (shape[Int(1)] * shape[Int(2)]) + idx[Int(1)] * shape[Int(2)] + idx[Int(2)]
            ],
        )

    @property
    def shape(self) -> TupleInt: ...

    def __getitem__(self, index: TupleInt) -> Int: ...


@array_ruleset.register
def _ndarray(shape: TupleInt, fn: Callable[[TupleInt], Int], idx: TupleInt):
    """
    Accessor rules for the `NDArray` constructor.

    `.shape` returns the stored shape. Indexing with a tuple `idx` applies the
    stored index function to it.
    """
    nda = NDArray(shape, fn)
    yield rewrite(nda.shape).to(shape)
    yield rewrite(nda[idx]).to(fn(idx))


# `subsume=True`: after the call is rewritten to the body below, the `cat(...)`
# term itself is subsumed, so extraction does not return it.
@function(subsume=True, ruleset=array_ruleset)
def cat(l: NDArray, r: NDArray) -> NDArray:
    """
    Concatenate `l` and `r` along the first axis.

    The arrays are expected to have the same number of dimensions and the same
    extent on every axis except the first; nothing here checks this. The result
    takes its dimension count from `l`, and its first extent is
    `l.shape[0] + r.shape[0]`.

    An index whose first component is below `l.shape[0]` reads from `l`
    unchanged. Any other index reads from `r` after `l.shape[0]` is subtracted
    from its first component.
    """
    return NDArray(
        TupleInt(
            l.shape.length,
            # Axis 0 grows to the combined extent; the other axes keep l's extents.
            lambda i: (i == Int(0)).if_bool(l.shape[Int(0)] + r.shape[Int(0)], l.shape[i]),
        ),
        lambda idx: (idx[Int(0)] < l.shape[Int(0)]).if_bool(
            # Inside `l`: read it directly. Otherwise shift axis 0 back into `r`'s coordinates.
            l[idx], r[TupleInt(r.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - l.shape[Int(0)], idx[i]))]
        ),
    )


# `subsume=True`: after the call is rewritten to the body below, the `drop(...)`
# term itself is subsumed, so extraction does not return it.
@function(subsume=True, ruleset=array_ruleset)
def drop(x: Int, arr: NDArray) -> NDArray:
    """
    Drop the first `x` entries of `arr` along its first axis.

    The result has first extent `arr.shape[0] - x` and the same extents as
    `arr` on the other axes. Element `idx` of the result is the element of
    `arr` at `idx` with `x` added to the first index component. `x` is not
    range-checked; presumably callers keep `0 <= x <= arr.shape[0]`.
    """
    return NDArray(
        TupleInt(
            arr.shape.length,
            # Only axis 0 shrinks, by `x`.
            lambda i: (i == Int(0)).if_bool(arr.shape[Int(0)] - x, arr.shape[i]),
        ),
        lambda idx: arr[
            TupleInt(
                arr.shape.length,
                #  Add x to the first index, so it skips the first x elements
                lambda i: (i == Int(0)).if_bool(idx[Int(0)] + x, idx[i]),
            )
        ],
    )


# `subsume=True`: after the call is rewritten to the body below, the `take(...)`
# term itself is subsumed, so extraction does not return it.
@function(subsume=True, ruleset=array_ruleset)
def take(x: Int, arr: NDArray) -> NDArray:
    """
    Take the first `x` entries of `arr` along its first axis.

    The result has first extent `x` and the same extents as `arr` on the other
    axes. Indexing it reads `arr` at the same index, unchanged. `x` is not
    range-checked; presumably callers keep `0 <= x <= arr.shape[0]`.
    """
    return NDArray(
        TupleInt(
            arr.shape.length,
            # Axis 0 becomes `x`; the other axes are unchanged.
            lambda i: (i == Int(0)).if_bool(x, arr.shape[i]),
        ),
        # Taking from the front needs no index shift.
        lambda idx: arr[idx],
    )


# Encode the PSI example: two 2x3x4 arrays and
# Amts = <2> take (<2> drop (RAMY cat AMY)).
shape = TupleInt.from_vec(Vec(Int(2), Int(3), Int(4)))
# The flat buffers are opaque named constants. The rules never need their
# actual contents, only how they are indexed.
RAMY = NDArray.from_memory(shape, constant("RAMY", TupleInt))
AMY = NDArray.from_memory(shape, constant("AMY", TupleInt))
Amts = take(Int(2), drop(Int(2), cat(RAMY, AMY)))
# Bare expression so the notebook displays the composed term.
Amts

In [28]:
# Register the quantities to inspect under names, saturate `array_ruleset`,
# then extract a representative expression for each one.
egraph = EGraph()
ndim = egraph.let("ndim", Amts.shape.length)
# Variable names are 1-based (shape_1..shape_3); the axis indices are 0-based.
shape_1 = egraph.let("shape_1", Amts.shape[Int(0)])
shape_2 = egraph.let("shape_2", Amts.shape[Int(1)])
shape_3 = egraph.let("shape_3", Amts.shape[Int(2)])
# A symbolic index (i, j, k) built from opaque `Int` constants, so the result
# describes every element rather than one concrete position.
idxs = TupleInt.from_vec(Vec(constant("i", Int), constant("j", Int), constant("k", Int)))
idxed = egraph.let("idxed", Amts[idxs])
# Index AMY the same way, for comparison with Amts[i, j, k].
amy_idxed = egraph.let("amy_idxed", AMY[idxs])

egraph.run(array_ruleset.saturate())
print(f"Amts.shape.length()={egraph.extract(ndim)}")
print(f"Amts.shape[0]={egraph.extract(shape_1)}")
print(f"Amts.shape[1]={egraph.extract(shape_2)}")
print(f"Amts.shape[2]={egraph.extract(shape_3)}")
print(f"Amts[i, j, k]={egraph.extract(idxed)}")
print(f"AMY[i, j, k]={egraph.extract(amy_idxed)}")

Amts.shape.length()=Int(3)
Amts.shape[0]=Int(2)
Amts.shape[1]=Int(3)
Amts.shape[2]=Int(4)
Amts[i, j, k]=((i + Int(2)) < Int(2)).if_bool(RAMY[(((i + Int(2)) * Int(12)) + (j * Int(4))) + k], AMY[((((i + Int(2)) - Int(2)) * Int(12)) + (j * Int(4))) + k])
AMY[i, j, k]=AMY[((i * Int(12)) + (j * Int(4))) + k]


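We can see that `Amts` is equal to `AMY`: they have the same shape, and indexing them produces the same result. For any valid index `i >= 0`, the condition `(i + Int(2)) < Int(2)` is false, so the `RAMY` branch is never taken. In the `AMY` branch, `(i + Int(2)) - Int(2)` is just `i`, so it reads `AMY[((i * Int(12)) + (j * Int(4))) + k]`, exactly like `AMY[i, j, k]`.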

With some basic range analysis we could make them simplify to the same expression in the e-graph as well.

If we want, we can also see all the intermediate steps to get to the indexed result.

In [29]:
# Start from a fresh e-graph that holds only the indexed expression, then
# saturate it while displaying how `idxed` is rewritten step by step (the
# printed forms and the visualizer widget below). `idxs` comes from the
# previous cell.
egraph = EGraph()
idxed = egraph.let("idxed", Amts[idxs])
egraph.saturate(array_ruleset, expr=idxed)

take(
    Int(2),
    drop(
        Int(2), cat(NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY), NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY))
    ),
)[TupleInt.from_vec(Vec[Int](i, j, k))] 

_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)
_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)
_NDArray_3 = NDArray(
    TupleInt(_NDArray_1.shape.length, lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),
    lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(
        _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]
    ),
)
_NDArray_4 = NDArray(
    TupleInt(_NDArray_3.shape.length, lambda i: (i == Int(0)).if_bool(_NDArray_3.shape[Int(0)] - Int(2), _NDArray_3.shape[i])),
    lambda idx: 

VisualizerWidget(egraphs=['{"nodes":{"primitive-i64-2":{"op":"2","children":[],"eclass":"i64-2","cost":1.0,"su…