<a href="https://colab.research.google.com/github/srush/tensor_puzzles/blob/main/Tensor_Puzzles.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchtyping hypothesis

In [None]:
import typing 
from torchtyping import TensorType
from hypothesis.extra.numpy import arrays
from hypothesis.strategies import integers, tuples, composite, floats
from hypothesis import given
import numpy as np
import torch

# Hypothesis strategy used for the size of every named tensor dimension
# (1 to 5 inclusive).
size = integers(1, 5)
# Short alias for converting numpy test data into torch tensors.
tensor = torch.tensor

# numpy scalar type -> equivalent torch dtype. Booleans are keyed by Python's
# builtin ``bool`` rather than a numpy type.
numpy_to_torch_dtype_dict = dict(
    [
        (bool, torch.bool),
        (np.uint8, torch.uint8),
        (np.int8, torch.int8),
        (np.int16, torch.int16),
        (np.int32, torch.int32),
        (np.int64, torch.int64),
        (np.float16, torch.float16),
        (np.float32, torch.float32),
        (np.float64, torch.float64),
        (np.complex64, torch.complex64),
        (np.complex128, torch.complex128),
    ]
)
# Inverse lookup (torch dtype -> numpy scalar type), used to choose the dtype
# of generated test arrays.
torch_to_numpy_dtype_dict = {
    torch_dtype: numpy_type
    for numpy_type, torch_dtype in numpy_to_torch_dtype_dict.items()
}
    
@composite
def spec(draw, x):
    """Hypothesis strategy that draws random arguments for the puzzle function ``x``.

    Reads ``x``'s torchtyping annotations (including ``"return"``) via
    ``typing.get_type_hints``. Each named dimension gets one size drawn from
    ``size``, and every annotation that uses that name shares it. Anonymous
    integer dimensions, such as the ``2`` in ``TensorType[2, "i"]``, keep
    their literal size. One numpy array is drawn per annotation with that
    shape and with the annotated dtype (``int`` when no dtype is given).

    Values are clipped to [-1000, 1000] and NaN/inf are replaced by 0. The
    ``"return"`` array is zero-filled so the caller can use it as the output
    buffer for the loop-based reference spec.

    Returns:
        dict mapping each annotated name, in annotation order, to a numpy array.
    """
    gth = typing.get_type_hints(x)

    def details_of(hint):
        # torchtyping stores its parsed details in the Annotated metadata.
        # Entry 0 holds the shape; entry 1, when present, holds the dtype.
        return hint.__metadata__[0]["details"]

    # Collect dimension names in first-seen order rather than via a set.
    # Set order for strings depends on hash randomization, which would pair
    # names with drawn sizes differently from run to run and make Hypothesis
    # failures non-reproducible.
    names = []
    for hint in gth.values():
        for dim in details_of(hint)[0].dims:
            if isinstance(dim.name, str) and dim.name not in names:
                names.append(dim.name)
    arr = draw(tuples(*[size for _ in names]))
    sizes = dict(zip(names, arr))

    ret = {}
    for k, hint in gth.items():
        details = details_of(hint)
        # Named dims share one drawn size; unnamed dims keep their fixed size.
        shape = tuple(
            sizes[dim.name] if isinstance(dim.name, str) else dim.size
            for dim in details[0].dims
        )
        dtype = (
            torch_to_numpy_dtype_dict[details[1].dtype]
            if len(details) >= 2
            else int
        )
        values = draw(arrays(shape=shape, dtype=dtype))
        # Clip magnitudes (presumably to keep puzzle arithmetic from
        # overflowing) and drop non-finite floats.
        values[values > 1000] = 1000
        values[values < -1000] = -1000
        ret[k] = np.nan_to_num(values, nan=0, neginf=0, posinf=0)
    # `[...]` (unlike `[:]`) also works for 0-d arrays.
    ret["return"][...] = 0
    return ret

def make_test(problem, problem_spec):
    """Build a Hypothesis property test that checks ``problem`` against ``problem_spec``.

    For each example drawn by ``spec(problem)`` the test does three things:

    1. Runs the loop-based ``problem_spec`` on the inputs. Its output buffer
       is the zero-filled ``"return"`` array converted to nested Python lists.
    2. Runs the tensor implementation ``problem`` on the same inputs,
       converted to torch tensors.
    3. Asserts the two results are close.

    Returns:
        The decorated test function; call it with no arguments to run it.
    """
    @given(spec(problem))
    def test_problem(d):
        # Separate the inputs from the output buffer without mutating `d`.
        inputs = [v for k, v in d.items() if k != "return"]
        out = d["return"].tolist()
        problem_spec(*inputs, out)
        out2 = problem(*map(tensor, inputs))
        # `assert_allclose` is deprecated. `assert_close` with
        # check_dtype=False, plus as_tensor on the actual value, keeps the old
        # lenient behaviour: dtypes may differ and non-tensor results are
        # accepted.
        torch.testing.assert_close(
            torch.as_tensor(out2), tensor(out), check_dtype=False
        )
    return test_problem

# Tensor Puzzles

A collection of puzzles for learning about tensors and broadcasting. 

## Puzzle 0

For each vector in a batch, compute its sum: $\text{out}_i = \sum_j a_{ij}$.

In [None]:
def problem0_spec(a, out):
    """Reference loop for a batched vector sum: ``out[i] = sum_j a[i][j]``.

    Writes into ``out`` in place. The row length is taken from ``a[0]``.
    """
    for row in range(len(out)):
        total = 0
        for col in range(len(a[0])):
            total += a[row][col]
        out[row] = total

In [None]:
def problem0(a: TensorType["i", "j"]) -> TensorType["i"]:
    """Batched vector sum: reduce each row of ``a`` over axis 1."""
    return torch.sum(a, dim=1)

make_test(problem0, problem0_spec)()

## Puzzle 1

Compute the outer product $a_i b_i^\top$ of each pair of vectors in a batch: $\text{out}_{ijk} = a_{ij} b_{ik}$.

In [None]:
def problem1_spec(a, b, out):
    """Reference loop for a batched outer product.

    Fills ``out[i][j][k] = a[i][j] * b[i][k]`` in place. The extents of the
    inner two loops are read from ``out[0]`` and ``out[0][0]``.
    """
    for batch in range(len(out)):
        for row in range(len(out[0])):
            for col in range(len(out[0][0])):
                left = a[batch][row]
                right = b[batch][col]
                out[batch][row][col] = left * right

In [None]:
def problem1(
    a: TensorType["i", "j"],
    b: TensorType["i", "k"],
) -> TensorType["i", "j", "k"]:
    """Batched outer product: ``out[i, j, k] = a[i, j] * b[i, k]``."""
    # Broadcast (i, j, 1) against (i, 1, k).
    lhs = a.unsqueeze(2)
    rhs = b.unsqueeze(1)
    return lhs * rhs

make_test(problem1, problem1_spec)()

## Puzzle 2



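Compute the sum of each contiguous group of values. `groups[i]` is `True` where a new group begins. `out[j]` is the sum of the $j$-th group, and the remaining entries stay 0. Any values that appear before the first `True` are accumulated into the last entry of `out`.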

In [None]:
def problem2_spec(groups, values, out):
    """Reference loop for summing contiguous groups of ``values``.

    A truthy ``groups[i]`` starts a new group, and ``values[i]`` becomes that
    group's first term in the next output slot. A falsy entry adds
    ``values[i]`` to the current slot. Before any group has started the
    current slot is -1, so such values accumulate into ``out[-1]``. Writes
    into ``out`` in place.
    """
    slot = -1  # index of the group currently being summed
    for idx, starts_group in enumerate(groups):
        if starts_group:
            slot += 1
            out[slot] = values[idx]
        else:
            out[slot] += values[idx]


In [None]:
def problem2(groups : TensorType["i", bool],
             values : TensorType["i"]
             )->TensorType["i"]:
    """Sum ``values`` within contiguous groups using only tensor operations.

    ``groups[i]`` is True where a new group starts, so ``groups.cumsum(0) - 1``
    is the group index of every element. Elements before the first True get
    index -1, which selects the identity's last row; they therefore land in
    the last output slot, as in ``problem2_spec``. Indexing the identity with
    these ids builds a one-hot (element -> group) matrix, and
    ``values @ onehot`` sums each group.
    """
    n = groups.shape[0]
    # Build the identity with `values`' dtype and device. torch.matmul does
    # not mix dtypes, so a hard-coded int identity fails for float inputs.
    onehot = torch.eye(n, dtype=values.dtype, device=values.device)[groups.cumsum(0) - 1]
    return values @ onehot

make_test(problem2, problem2_spec)()

## Who needs libraries?

## Puzzle 4 - diag

Extract the diagonal of each square matrix in a batch: $\text{out}_{ij} = a_{ijj}$.

In [None]:
def problem4_spec(a, out):
    """Reference loop for a batched diagonal: ``out[i][j] = a[i][j][j]``.

    ``out`` is the nested Python list produced by ``tolist()``. It therefore
    has to be indexed as ``out[i][j]``; tuple indexing ``out[i, j]`` raises
    TypeError on lists. The inner loop runs over the row length, not the
    batch size. Writes into ``out`` in place.
    """
    for i in range(len(out)):
        for j in range(len(out[i])):
            out[i][j] = a[i][j][j]

In [None]:
def problem4(a: TensorType["i", "j", "j"]) -> TensorType["i", "j"]:
    """Diagonal of every square matrix in the batch, via advanced indexing."""
    diag_idx = torch.arange(a.size(1))
    # Using the same index tensor on the last two axes selects a[:, d, d].
    diagonals = a[:, diag_idx, diag_idx]
    return diagonals

make_test(problem4, problem4_spec)()

NameError: ignored

## Puzzle 5 - eyes_like

Return an identity matrix with the same size as the square input. The input's values are ignored.



In [None]:
def problem5_spec(ignore, out):
    """Reference loop for eyes_like: set the diagonal of ``out`` to 1.

    ``ignore`` only fixes the shape and is not read. ``out`` is assumed to be
    zero-filled (the test harness zeroes it) and is a nested Python list, so
    it is indexed as ``out[i][i]``. The loop length comes from ``out``; the
    previous code referenced an undefined name ``a``.
    """
    for i in range(len(out)):
        out[i][i] = 1

In [None]:
def problem5(ignore : TensorType["i", "i"]
             )->TensorType["i", "i"]:
    """Identity matrix with the same size as ``ignore``; its values are unused.

    The previous version had three defects. It used the keyword ``is`` as a
    variable (SyntaxError), read an undefined ``a``, and returned a numpy
    array from ``np.where``. This version returns an int64 tensor.
    """
    idx = torch.arange(ignore.shape[0], device=ignore.device)
    # A column of indices compared with a row is True exactly on the diagonal.
    return (idx[:, None] == idx[None, :]).long()

make_test(problem5, problem5_spec)()

NameError: ignored

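## Puzzle 6 - triu_like

Return a 0/1 matrix with the same size as the square input. The input's values are ignored. Note: the reference spec below sets 1 strictly *below* the diagonal ($i > j$), so despite its name it produces a strictly lower-triangular mask.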



In [None]:
def problem6_spec(ignore, out):
    """Reference loop: ``out[i][j] = 1`` where ``i > j``, else 0.

    ``ignore`` only fixes the shape and is not read. ``out`` is a nested
    Python list, so it is indexed ``out[i][j]``; tuple indexing raises
    TypeError on lists.

    NOTE(review): the puzzle is titled "triu_like", but ``i > j`` marks the
    strict *lower* triangle. ``problem6`` uses the same ``>`` convention, so
    the two agree. Confirm which triangle is intended.
    """
    for i in range(len(out)):
        for j in range(len(out)):
            out[i][j] = 1 if i > j else 0

In [None]:
def problem6(ignore : TensorType["i", "i"]
             )->TensorType["i", "i"]:
    """0/1 mask with the shape of ``ignore``: 1 where row > column, else 0.

    This matches ``problem6_spec`` (which uses ``i > j``). The previous
    version used the keyword ``is`` as a variable (SyntaxError), read an
    undefined ``a``, and returned a numpy array; this returns an int64 tensor.
    """
    idx = torch.arange(ignore.shape[0], device=ignore.device)
    # Column of row indices vs. row of column indices -> True strictly below the diagonal.
    return (idx[:, None] > idx[None, :]).long()

make_test(problem6, problem6_spec)()

NameError: ignored

## Puzzle 7 - vstack

Stack two vectors of the same length into a matrix with two rows: row 0 is `a` and row 1 is `b`.



In [None]:
def problem7_spec(a, b, out):
    """Reference loop for vstack: row 0 of ``out`` is ``a`` and row 1 is ``b``.

    ``out`` has shape (2, i) as a nested Python list, so the loop runs over
    the row length ``len(out[0])``, not ``len(out)`` (which is always 2).
    Writes into ``out`` in place.
    """
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

In [None]:
def problem7(a : TensorType["i"],
             b : TensorType["i"]
             )->TensorType[2, "i"]:
    """Stack ``a`` and ``b`` as rows 0 and 1 of a (2, i) tensor using ``torch.where``.

    Fixes two defects in the previous version: a missing comma between the
    parameters (SyntaxError), and a condition (``> 0``) that put ``b`` in
    row 0 and ``a`` in row 1, the reverse of ``problem7_spec``.
    """
    # (2, 1) condition: True for row 0 only; it broadcasts against (1, i) rows.
    is_first_row = torch.arange(2, device=a.device)[:, None] == 0
    return torch.where(is_first_row, a[None], b[None])

make_test(problem7, problem7_spec)()

SyntaxError: ignored

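## Puzzle 8 - roll

*Placeholder:* the cells below are still an exact copy of the vstack puzzle and redefine `problem7_spec` and `problem7`. The roll puzzle itself has not been written yet.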



In [None]:
def problem7_spec(a, b, out):
    """Reference loop for vstack: row 0 of ``out`` is ``a`` and row 1 is ``b``.

    NOTE(review): this cell sits under the "roll" heading but is a copy of
    the vstack spec; it is kept in sync with that version.

    ``out`` has shape (2, i) as a nested Python list, so the loop runs over
    the row length ``len(out[0])``, not ``len(out)`` (which is always 2).
    Writes into ``out`` in place.
    """
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

In [None]:
def problem7(a : TensorType["i"],
             b : TensorType["i"]
             )->TensorType[2, "i"]:
    """Stack ``a`` and ``b`` as rows 0 and 1 of a (2, i) tensor using ``torch.where``.

    NOTE(review): this cell sits under the "roll" heading but is a copy of
    the vstack solution; it is kept in sync with that version.

    Fixes two defects in the previous version: a missing comma between the
    parameters (SyntaxError), and a condition (``> 0``) that put ``b`` in
    row 0 and ``a`` in row 1, the reverse of ``problem7_spec``.
    """
    # (2, 1) condition: True for row 0 only; it broadcasts against (1, i) rows.
    is_first_row = torch.arange(2, device=a.device)[:, None] == 0
    return torch.where(is_first_row, a[None], b[None])

make_test(problem7, problem7_spec)()

SyntaxError: ignored