<a href="https://colab.research.google.com/github/srush/Tensor-Puzzles-Penzai/blob/main/Tensor_Puzzlers_Penzai.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tensor Puzzles - Penzai Addition
- by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp)


This is a version of the [tensor puzzles](https://github.com/srush/Tensor-Puzzles) implemented the [Penzai](https://penzai.readthedocs.io/en/stable/notebooks/named_axes.html) library.


I recommend running in Colab. Click here and copy the notebook to get start.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/Tensor-Puzzles/blob/main/Tensor%20Puzzlers.ipynb)

In [178]:
#!pip install -qqq jaxtyping hypothesis pytest penzai
import jax.numpy as np
import numpy as onp
from penzai import pz
arange = pz.nx.arange
where = pz.nx.nmap(np.where)
wrap = pz.nx.wrap
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer(force_continuous=True, around_zero=True,  prefers_column=["j"], prefers_row=["i"]))

In [179]:
import inspect
import random
from jaxtyping import Int32
NamedArray = pz.nx.NamedArray
def make_test(name, problem, problem_spec, add_sizes=[],
              init_size = {},
              constraint=lambda d: d):
    args = {}
    signature = inspect.signature(problem)
    for n, p in signature.parameters.items():
        args[n] = [d.name for d in p.annotation.dims]
    args["return"] = [d.name for d in signature.return_annotation.dims]

    def make_instance():
        example = {}
        reg = {}
        sizes = {}
        for k in init_size:
            sizes[k] = init_size[k]
        for n in args:
            size = {}
            for name in args[n]:
                if name[0] not in sizes:
                    sizes[name[0]] = random.randint(2, 7)
                size[name] = sizes[name[0]]
            if "_s" in n:
                l = list(size.keys())[0]
                example[n] = pz.nx.arange(l, size[l])
            else:
                v = onp.random.randint(-5, 5, list(size.values()))
                example[n] = pz.nx.wrap(v).tag(*args[n])
        example = constraint(example)
        for n in args:
            x = example[n]
            x = x.untag(*args[n])
            reg[n] = x.unwrap().tolist()
            if len(args[n]) == 0:
                reg[n] = [0]
        return example, reg

    examples = []
    correct = 0
    for i in range(3):
        example, reg = make_instance()
        # out = example["return"].tolist()
        del example["return"]
        problem_spec(*reg.values())
        if len(reg["return"]) == 1:
            reg["return"] = reg["return"][0]
        yours = None
        yours = problem(**example)
        example["target"] = wrap(reg["return"])
        example["target"] = example["target"].tag(*args["return"])
        if yours is not None:
            example["yours"] = yours
        same = example["target"] == example["yours"]
        if same.untag(*same.named_shape.keys()).unwrap().all():
            correct += 1
        examples.append(example)
    if correct == 3:
        print("Correct")
    else:
        print("Failure")
    return examples


## Rules



1. Each puzzle needs to be solved in 1 line (<80 columns) of code.
2. You are only allowed to use `contract`, `where` and indexing.
3. You are *not allowed* anything else. No `view`, `sum`, `take`, `squeeze`, `tensor`.

In [180]:
# Example of named infix ops.
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]

[{"a": a, "b":b, "ret": a + b} for a, b in zip(a, b)]

In [181]:
# Example of where
examples = [(wrap([False, True], "i"), wrap([1, 1], "i"), wrap([-1, 0], "i")),
            (wrap([[False, True], [True, False]], "i", "j"), wrap([0, 1], "i"), wrap([-1, 0], "j")),
           ]
[{"q": q, "a":a, "b":b, "ret": where(q, a, b)} for q, a, b in examples]

In [203]:
# Example of contraction
def contract(n, *ts):
    t = 1
    for t2 in ts:
        t = t * t2
    return pz.nx.nmap(np.sum)(t.untag(n))

a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]

[{"a": a, "b":b, "ret": contract("i", a * b)} for a, b in zip(a, b)]

## Puzzle 1 - ones

Compute [ones](https://numpy.org/doc/stable/reference/generated/numpy.ones.html) - the vector of all ones.

In [183]:
def ones_spec(i_s, out):
    for i in i_s:
        out[i] = 1

def ones(i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return i_s * 0 + 1

make_test("one", ones, ones_spec)

Correct


## Puzzle 2 - sum

Compute [sum](https://numpy.org/doc/stable/reference/generated/numpy.sum.html) - the sum of a vector.

In [184]:
def sum_spec(a, out):
    for i in range(len(a)):
        out[0] = out[0] + a[i]

def sum(a: Int32[NamedArray, "i"]) -> Int32[NamedArray, ""]:
    return contract("i", a)

make_test("sum", sum, sum_spec)

Correct


## Puzzle 3 - outer

Compute [outer](https://numpy.org/doc/stable/reference/generated/numpy.outer.html) - the outer product of two vectors.

In [185]:
def outer_spec(a, b, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            out[i][j] = a[i] * b[j]

def outer(a: Int32[NamedArray, "i"], b : Int32[NamedArray, "j"]) -> Int32[NamedArray, "i j"]:
    return a * b

make_test("outer", outer, outer_spec)

Correct


## Puzzle 4 - diag

Compute [diag](https://numpy.org/doc/stable/reference/generated/numpy.diag.html) - the diagonal vector of a square matrix.

In [186]:
def diag_spec(a, i1_s, out):
    for i in range(len(a)):
        out[i] = a[i][i]

def diag(a: Int32[NamedArray, "i1 i2"], i1_s: Int32[NamedArray, "i1"]) -> Int32[NamedArray, "i1"]:
    return a[{"i1": i1_s, "i2": i1_s}]


make_test("diag", diag, diag_spec)

Correct


## Puzzle 5 - eye

Compute [eye](https://numpy.org/doc/stable/reference/generated/numpy.eye.html) - the identity matrix.

In [187]:
def eye_spec(i1_s, i2_s, out):
    for i in i1_s:
        for j in i2_s:
            if i == j:
                out[i][j] = 1
            else:
                out[i][j] = 0

def eye(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return where(i1_s == i2_s, 1, 0)


make_test("eye", eye, eye_spec)

Correct


## Puzzle 6 - triu

Compute [triu](https://numpy.org/doc/stable/reference/generated/numpy.triu.html) - the upper triangular matrix.

In [188]:
def triu_spec(i1_s, i2_s, out):
    for i in i1_s:
        for j in i2_s:
            if i <= j:
                out[i][j] = 1
            else:
                out[i][j] = 0

def triu(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return where(i1_s <= i2_s, 1, 0)


make_test("triu", triu, triu_spec)

Correct


## Puzzle 7 - cumsum

Compute [cumsum](https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html) - the cumulative sum.

In [189]:
def cumsum_spec(a, i1_s, i2_s, out):
    total = 0
    for i in range(len(out)):
        out[i] = total + a[i]
        total += a[i]

def cumsum(a: Int32[NamedArray, "i1"], i1_s : Int32[NamedArray, "i1"], i2_s: Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    return contract("i1", where(i1_s <= i2_s, 1, 0), a)

make_test("cumsum", cumsum, cumsum_spec)

Correct


## Puzzle 8 - diff

Compute [diff](https://numpy.org/doc/stable/reference/generated/numpy.diff.html) - the running difference.

In [190]:
def diff_spec(a, i_s, out):
    out[0] = a[0]
    for i in range(0, len(out)):
        out[i] = a[i] - a[(i - 1)]

def diff(a: Int32[NamedArray, "i"], i1_s : Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a - a[{"i": i1_s - 1}]

make_test("diff", diff, diff_spec)

Correct


## Puzzle 9 - stack

Compute [vstack](https://numpy.org/doc/stable/reference/generated/numpy.vstack.html) - the matrix of two vectors

In [191]:
def stack_spec(a, b, out):
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

def stack(a: Int32[NamedArray, "i"], b: Int32[NamedArray, "i"]) -> Int32[NamedArray, "j i"]:
    return where(arange("j", 2) == 1, b, a)


make_test("stack", stack, stack_spec, init_size={"j" : 2})

Correct


## Puzzle 10 - roll

Compute [roll](https://numpy.org/doc/stable/reference/generated/numpy.roll.html) - the vector shifted 1 circular position.

In [192]:
def roll_spec(a, i_s, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]

def roll(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": (i_s + 1) % i_s.named_shape["i"]}]


make_test("roll", roll, roll_spec)

Correct


## Puzzle 11 - flip

Compute [flip](https://numpy.org/doc/stable/reference/generated/numpy.flip.html) - the reversed vector

In [193]:
def flip_spec(a, i_s, out):
    for i in range(len(out)):
        out[i] = a[len(out) - i - 1]

def flip(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": (-i_s - 1)}]

make_test("flip", flip, flip_spec, add_sizes=["i"])

Correct


## Puzzle 12 - compress


Compute [compress](https://numpy.org/doc/stable/reference/generated/numpy.compress.html) - keep only masked entries (left-aligned).

In [194]:
def compress_spec(g, v, i1_s, i2_s, out):
    j = 0
    for i in range(len(out)):
        out[i] = 0
    for i in range(len(g)):
        if g[i] > 1:
            out[j] = v[i]
            j += 1

def compress(g: Int32[NamedArray, "i1"], v: Int32[NamedArray, "i2"], i1_s:Int32[NamedArray, "i1"], i2_s:Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    # I don't know how to do this one!
    return g

make_test("compress", compress, compress_spec)

Failure


## Puzzle 13 - pad_to


Compute pad_to - eliminate or add 0s to change size of vector.

In [195]:
def pad_to_spec(a, i_s, j_s, out):
    for i in range(len(out)):
        if i < len(a):
            out[i] = a[i]
        else:
            out[i] = 0

def pad_to(a: Int32[NamedArray, "i"], i_s:Int32[NamedArray, "i"], j_s:Int32[NamedArray, "j"])  -> Int32[NamedArray, "j"]:
    return contract("i", a, where(j_s == i_s, 1, 0))


make_test("pad_to", pad_to, pad_to_spec)

Correct


## Puzzle 14 - sequence_mask


Compute [sequence_mask](https://www.tensorflow.org/api_docs/python/tf/sequence_mask) - pad out to length per batch.

In [196]:
# Didn't do
# def sequence_mask_spec(values, length, out):
#     for i in range(len(out)):
#         for j in range(len(out[0])):
#             if j < length[i]:
#                 out[i][j] = values[i][j]
#             else:
#                 out[i][j] = 0

# def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
#     pass


# def constraint_set_length(d):
#     d["length"] = d["length"] % d["values"].shape[1]
#     return d

# make_test("sequence_mask",
#     sequence_mask, sequence_mask_spec, constraint=constraint_set_length
# )

## Puzzle 15 - bincount

Compute [bincount](https://numpy.org/doc/stable/reference/generated/numpy.bincount.html) - count number of times an entry was seen.

In [197]:
def bincount_spec(a, i_s, j1_s, j2_s, out):
    for i in range(len(out)):
        out[i] = 0
    for i in range(len(a)):
        out[a[i]] += 1

def bincount(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"],
             j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", eye(j1_s, j2_s)[{"j1": a}])


def constraint_set_max(d):
    d["a"] = d["a"] % d["return"].named_shape["j2"]
    return d


make_test("bincount",
    bincount, bincount_spec, constraint=constraint_set_max
)

Correct


## Puzzle 16 - scatter_add

Compute [scatter_add](https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/add.html) - add together values that link to the same location.

In [198]:
def scatter_add_spec(values, link, j1_s, j2_s, out):
    for i in range(len(out)):
        out[i] = 0
    for j in range(len(values)):
        out[link[j]] += values[j]

def scatter_add(values: Int32[NamedArray, "i"], link: Int32[NamedArray,"i"],
                j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", values, eye(j1_s, j2_s)[{"j1": link}])


def constraint_set_max(d):
    d["link"] = d["link"] % d["return"].named_shape["j2"]
    return d

make_test("scatter_add",
    scatter_add, scatter_add_spec, constraint=constraint_set_max
)

Correct


In [201]:
import inspect
fns = (ones, sum, outer, diag, eye, triu, cumsum, diff, stack, roll, flip,
       compress, pad_to,  bincount, scatter_add)

for fn in fns:
    lines = [l for l in inspect.getsource(fn).split("\n") if not l.strip().startswith("#")]

    if len(lines) > 3:
        print(fn.__name__, len(lines[2]), "(more than 1 line)")
    else:
        print(fn.__name__, len(lines[1]))

ones 22
sum 27
outer 16
diag 38
eye 36
triu 36
cumsum 55
diff 33
stack 43
roll 53
flip 31
compress 12
pad_to 52
bincount 52 (more than 1 line)
scatter_add 63 (more than 1 line)
