# PyTorch Challenges

This set of challenges is about broadcasting, one of the key concepts when working with tensors.

[Sasha Rush](https://twitter.com/srush_nlp) compiled a set of [16 Tensor mini-puzzles](https://github.com/srush/Tensor-Puzzles) that involve reasoning about broadcasting in a constrained setting: the only PyTorch function you may use is `torch.arange`. Can you do it?

Here, I've extended his list to 26 puzzles! 

**Rules**

- Each puzzle needs to be solved in 1 line (<80 columns) of code.
- You are allowed @, arithmetic, comparison, shape, any indexing (e.g. `a[:j], a[:, None], a[arange(10)]`), and previous puzzle functions.
- To start off, we give you an implementation for the `torch.arange` function.

**Anti-Rules**
- Nothing else. No `.view, .sum, .take, .squeeze, .tensor`.
- No cheating. Stack Overflow is great, but this is about first principles.
- Hint... these puzzles are mostly about [Broadcasting](https://pytorch.org/docs/master/notes/broadcasting.html). Make sure you understand this rule.


---

In [None]:
%load_ext autoreload
%autoreload 2

In [169]:
import torch
from spec import make_test, run_test, TT

### arange

This is given for free! Think of it as a "for-loop".

In [170]:
def arange(i: int):
    """Return ``[0, 1, ..., i - 1]`` -- the one primitive the puzzles allow.

    Think of it as a vectorised ``for`` loop over ``range(i)``.
    """
    start, stop = 0, i
    return torch.arange(start, stop)

arange(6)

tensor([0, 1, 2, 3, 4, 5])

### where

In [171]:
def where(q, a, b):
    """Pick ``a`` where ``q`` is True and ``b`` elsewhere, via arithmetic only.

    ``q`` must be a boolean tensor (``~`` is a logical NOT only for bools).
    The unselected branch is multiplied by 0, so an ``inf``/``nan`` there
    still turns the result into ``nan``.
    """
    return ~q * b + q * a

where(arange(4) % 2 == 0, arange(4), -1)

tensor([ 0, -1,  2, -1])

### ones

In [172]:
def ones(i: int):
    """Return an int64 vector of ``i`` ones."""
    return arange(i) * 0 + 1

ones(4)

tensor([1, 1, 1, 1])

### sum

In [173]:
def sum(a: torch.Tensor):
    """Sum ``a`` over its first dimension by contracting with a ones vector.

    Shadows the builtin ``sum`` (it is the puzzle's name); later puzzles call
    this version. ``ones`` is int64 and ``@`` does not type-promote, so ``a``
    must be an int64 tensor.
    """
    return ones(len(a)) @ a

sum(arange(4))

tensor(6)

### outer

In [174]:
def outer(a: torch.Tensor, b: torch.Tensor):
    """Outer product of two vectors: ``result[r, c] == a[r] * b[c]``."""
    return b[None, :] * a[:, None]

outer(arange(4), ones(3))

tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

### diag

In [175]:
def diag(a: torch.Tensor):
    """Return the main diagonal ``a[k, k]`` of a 2-D tensor.

    The result has ``min(rows, cols)`` entries, as with ``torch.diag``, so
    tall matrices (more rows than columns) work as well as square/wide ones.
    """
    return a[arange(min(a.shape[:2])), arange(min(a.shape[:2]))]

diag(outer(arange(4), ones(4)))

tensor([0, 1, 2, 3])

### eye

In [176]:
def eye(j: int):
    """Return the ``j x j`` int64 identity matrix."""
    return 1 * (arange(j) == arange(j)[:, None])

eye(4)

tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])

### triu

In [177]:
def triu(j: int):
    """Return the ``j x j`` upper-triangular 0/1 matrix (diagonal included)."""
    return 1 * (arange(j) >= arange(j)[:, None])

triu(4)

tensor([[1, 1, 1, 1],
        [0, 1, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 1]])

### cumsum

In [178]:
def cumsum(a: torch.Tensor):
    """Inclusive prefix sum: ``result[c] == a[0] + a[1] + ... + a[c]``.

    Column ``c`` of ``triu(n)`` is 1 exactly for rows ``r <= c``, so the
    vector-matrix product ``a @ triu(n)`` sums the right prefix directly --
    no ``n x n`` outer product, and an empty ``a`` yields an empty result.
    ``triu`` is int64 and ``@`` does not type-promote, so ``a`` must be int64.
    """
    return a @ triu(a.shape[0])

cumsum(torch.arange(4))

tensor([0, 1, 3, 6])

### diff

In [179]:
def diff(a: torch.Tensor, i: int):
    """Return ``a[0]`` followed by the differences ``a[k] - a[k - 1]``.

    ``i`` is the length of ``a``. At position 0 the previous-index is clamped
    to 0, so ``a[0]`` is subtracted and then added back (unchanged for
    finite values).
    """
    return a - a[(arange(i) - 1) * (arange(i) > 0)] + a * (arange(i) < 1)

diff(arange(4), 4)

tensor([0, 1, 1, 1])

### vstack

In [180]:
def vstack(a: torch.Tensor, b: torch.Tensor):
    """Stack two equal-length vectors as the rows of a ``2 x n`` tensor."""
    return a * eye(2)[:, :1] + b * eye(2)[:, 1:]

vstack(arange(4), ones(4))

tensor([[0, 1, 2, 3],
        [1, 1, 1, 1]])

### roll

In [181]:
def roll(a: torch.Tensor, i: int):
    """Rotate the length-``i`` vector ``a`` left by one position.

    Entry ``k`` reads ``a[k + 1]``; the last entry wraps around to ``a[0]``.
    """
    return a[where(arange(i) + 1 < i, arange(i) + 1, 0)]

roll(arange(4), 4)

tensor([1, 2, 3, 0])

### flip

In [182]:
def flip(a: torch.Tensor, i: int):
    """Reverse the length-``i`` vector ``a``."""
    return a[(i - 1) - arange(i)]

flip(arange(4), 4)

tensor([3, 2, 1, 0])

### compress

In [183]:
def compress(g: torch.Tensor, v: torch.Tensor, i: int):
    """Pack the entries ``v[g]`` at the front of a length-``i`` vector.

    ``g`` is a boolean mask; slots after the kept values are zero. Row ``r``
    of ``eye(i)`` routes the ``r``-th kept value to output position ``r``.
    """
    return sum(v[g][:, None] * eye(i)[:sum(g * 1)])

compress(torch.tensor([False, True, True]), arange(3), 3)

tensor([1, 2, 0])

### pad_to

In [184]:
def pad_to(a: torch.Tensor, i: int, j: int):
    """Zero-pad (or truncate) the length-``i`` vector ``a`` to length ``j``."""
    return sum(a[:, None] * (arange(i)[:, None] == arange(j)))

pad_to(arange(3), 3, 5)

tensor([0, 1, 2, 0, 0])

### sequence_mask

In [185]:
def sequence_mask(values: torch.Tensor, length: torch.Tensor):
    """For 2-D ``values``, zero entry ``[r, c]`` wherever ``c >= length[r]``."""
    return (arange(values.shape[-1]) < length[:, None]) * values

sequence_mask(outer(ones(4), ones(3)), torch.tensor([2,2,1,3]))

tensor([[1, 1, 0],
        [1, 1, 0],
        [1, 0, 0],
        [1, 1, 1]])

### bincount

In [186]:
def bincount(a: torch.Tensor, j: int):
    """Count how often each value ``0 .. j - 1`` occurs in the vector ``a``.

    Values outside that range match no column and are ignored.
    """
    return sum((a[:, None] == arange(j)) * 1)

bincount(torch.tensor([2, 1, 3, 3, 1, 2, 2, 2, 1, 0]), 4)

tensor([1, 3, 4, 2])

### scatter_add

In [187]:
def scatter_add(values: torch.Tensor, link: torch.Tensor, j: int):
    """Accumulate ``values[k]`` into slot ``link[k]`` of a length-``j`` result."""
    return sum(outer(values, ones(j)) * (link[:, None] == arange(j)))

scatter_add(torch.tensor([5,1,7,2,3,2,1,3]), torch.tensor([0,0,1,0,2,2,3,3]), 4)

tensor([8, 7, 5, 4])

### flatten

In [189]:
def flatten(a: torch.Tensor, i: int, j: int):
    """Flatten the ``i x j`` tensor ``a`` into a vector in row-major order.

    An all-True ``i x j`` boolean mask selects every element; boolean-mask
    indexing returns the selected elements in row-major order.
    """
    return a[ones(i)[:, None] == ones(j)]

flatten(arange(16).view(4, 4), 4, 4)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

### linspace

In [190]:
def linspace(i: float, j: float, n: int):
    """Return ``n`` evenly spaced points from ``i`` to ``j`` (both included).

    ``max(1, n - 1)`` keeps ``n == 1`` from dividing by zero; it yields ``[i]``.
    """
    return i + arange(n) * (j - i) / max(1, n - 1)

linspace(0, 1, 10)

tensor([0.0000, 0.1111, 0.2222, 0.3333, 0.4444, 0.5556, 0.6667, 0.7778, 0.8889,
        1.0000])

### heaviside

In [191]:
def heaviside(a: torch.Tensor, b: torch.Tensor):
    """Step function: 1 where ``a > 0``, ``b`` where ``a == 0``, else 0.

    NOTE: ``b`` is multiplied by a 0/1 mask, so an ``inf``/``nan`` in ``b``
    at a position where ``a != 0`` still produces ``nan`` there.
    """
    return (a == 0) * b + (a > 0)

heaviside(torch.tensor([1, 0, -2]), torch.randn(3))

tensor([ 1.0000, -2.6444,  0.0000])

### hstack

In [192]:
def hstack(a: torch.Tensor, b: torch.Tensor):
    """Stack two equal-length vectors as the columns of an ``n x 2`` tensor."""
    return outer(a, eye(2)[0]) + outer(b, eye(2)[1])

hstack(arange(3), ones(3))

tensor([[0, 1],
        [1, 1],
        [2, 1]])

### view (1d to 2d)

In [193]:
def view(a: torch.Tensor, i: int, j: int):
    """Reshape the 1-D tensor ``a`` into an ``i x j`` tensor, row-major.

    Entry ``[r, c]`` is ``a[j * r + c]``. Only the first ``i * j`` elements
    are used; when ``i * j > len(a)`` the indexing raises ``IndexError``
    rather than wrapping around and silently repeating data.
    """
    return a[j * arange(i)[:, None] + arange(j)]

view(arange(6), 3, 2)

tensor([[0, 1],
        [2, 3],
        [4, 5]])

### repeat (1d)

In [194]:
def repeat(a: torch.Tensor, d: int):
    """Concatenate ``d`` copies of the vector ``a``: ``[a, a, ..., a]``."""
    return flatten(outer(ones(d), a), d, len(a))

repeat(arange(5),  3)

tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])

### repeat_interleave (1d)

In [195]:
def repeat_interleave(a: torch.Tensor, d: int):
    """Repeat every element of ``a`` ``d`` times in a row.

    Row ``k`` of ``outer(a, ones(d))`` is ``a[k]`` repeated ``d`` times, so
    flattening it row-major gives ``[a0] * d + [a1] * d + ...``.
    """
    return flatten(outer(a, ones(d)), len(a), d)

repeat_interleave(arange(5), 3)

tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])

### chunk

In [198]:
def chunk(a: torch.Tensor, c: int):
    """Split the vector ``a`` into a list of ``c`` equal-length pieces.

    Each piece has ``len(a) // c`` elements; the ``len(a) % c`` leftover
    trailing elements are dropped.
    """
    return [*view(a, c, len(a) // c)]

chunk(torch.arange(12), 6)

[tensor([0, 1]),
 tensor([2, 3]),
 tensor([4, 5]),
 tensor([6, 7]),
 tensor([8, 9]),
 tensor([10, 11])]

### nonzero

In [200]:
def nonzero(a: torch.Tensor, i: int, j: int):
    """Return the ``(row, col)`` coordinates of the nonzero entries of ``a``.

    ``a`` is an ``i x j`` tensor; the result has one row per nonzero entry,
    in row-major order. ``outer(arange(i), ones(j))`` holds each entry's row
    index and ``outer(ones(i), arange(j))`` its column index; masking both
    grids with ``a != 0`` and stacking them column-wise gives the pairs.
    """
    return hstack(outer(arange(i),ones(j))[a!=0],outer(ones(i),arange(j))[a!=0])

nonzero(eye(3), 3, 3)

tensor([[0, 0],
        [1, 1],
        [2, 2]])

### bucketize

In [201]:
def bucketize(v: torch.Tensor, boundaries: torch.Tensor):
    """For each ``v[k]``, count how many ``boundaries`` are strictly below it.

    For sorted ``boundaries`` this is the insertion index returned by
    ``torch.bucketize`` with ``right=False``.
    """
    return sum((boundaries[:, None] < v[None, :]) * 1)

bucketize(torch.tensor([3, 6, 9]), torch.tensor([1, 3, 5, 7, 9]))

tensor([1, 3, 4])