# PyTorch Challenges

[Sasha Rush](https://twitter.com/srush_nlp) compiled a set of [16 Tensor mini-puzzles](https://github.com/srush/Tensor-Puzzles) that involve reasoning about broadcasting in a constrained setting: you may use only a single PyTorch function, `torch.arange`. Can you do it?

Here, I've extended his list to 27 puzzles! 

**Rules**

- Each puzzle needs to be solved in 1 line (<80 columns) of code.
- You are allowed @, arithmetic, comparison, shape, any indexing (e.g. `a[:j], a[:, None], a[arange(10)]`), and previous puzzle functions.
- To start off, we give you an implementation for the `torch.arange` function.
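
As a rough illustration of what the rules allow, the sketch below shows how each permitted indexing form changes a tensor's shape. The variable names are made up for this example and are not part of the puzzles.

```python
import torch

a = torch.arange(5)            # tensor([0, 1, 2, 3, 4]), shape (5,)

prefix = a[:3]                 # slicing keeps the first 3 entries, shape (3,)
column = a[:, None]            # None adds a new axis, shape (5, 1)
picked = a[torch.arange(2)]    # indexing with a tensor gathers entries 0 and 1
dot = a @ a                    # @ on two 1-D tensors is a dot product

print(prefix.shape, column.shape, picked, dot)
```

Note that `@` reduces over an axis, which is the only reduction the rules give you.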

**Anti-Rules**
- Nothing else. No `.view, .sum, .take, .squeeze, .tensor`.
- No cheating. Stack Overflow is great, but this is about first principles.
- Hint... these puzzles are mostly about [Broadcasting](https://pytorch.org/docs/master/notes/broadcasting.html). Make sure you understand this rule, which is a key concept for dealing with n-dimensional arrays.
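
To make the broadcasting hint concrete, here is a minimal sketch (not a puzzle solution) of how a size-1 dimension gets stretched when two tensors of different shapes are combined. The names `rows`, `cols` and `grid` are illustrative only.

```python
import torch

rows = torch.arange(3)[:, None]   # shape (3, 1)
cols = torch.arange(4)            # shape (4,), treated as (1, 4)

# Size-1 dimensions are stretched to match the other operand, so the
# result has shape (3, 4) and entry [i, j] equals i + j.
grid = rows + cols

print(grid.shape)
print(grid)
```

The same stretching applies to comparisons and multiplication, not just addition.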

🐶🐶🐶 After you convince yourself your code is correct, run the cell to test it. If the test succeeds, you will get a puppy 🐶🐶🐶.

List of puzzles:

1. [where](#1\)-where)
2. [ones](#2\)-ones)
3. [sum](#3\)-sum)
4. [outer](#4\)-outer)
5. [diag](#5\)-diag)
6. [eye](#6\)-eye)
7. [triu](#7\)-triu)
8. [cumsum](#8\)-cumsum)
9. [diff](#9\)-diff)
10. [vstack](#10\)-vstack)
11. [roll](#11\)-roll)
12. [flip](#12\)-flip)
13. [compress](#13\)-compress)
14. [pad_to](#14\)-pad_to)
15. [sequence_mask](#15\)-sequence_mask)
16. [bincount](#16\)-bincount)
17. [scatter_add](#17\)-scatter_add)
18. [flatten](#18\)-flatten)
19. [linspace](#19\)-linspace)
20. [heaviside](#20\)-heaviside)
21. [hstack](#21\)-hstack)
22. [view](#22\)-view-\(1d-to-2d\))
23. [repeat](#23\)-repeat-\(1d\))
24. [repeat_interleave](#24\)-repeat_interleave-\(1d\))
25. [chunk](#25\)-chunk)
26. [nonzero](#26\)-nonzero)
27. [bucketize](#27\)-bucketize)

---

## Setup

In [None]:
!pip install -qqq torchtyping hypothesis pytest

In [None]:
import torch
from spec import make_test, run_test, TT

---

### arange

This one is given! Think of it as a "for-loop".

In [None]:
def arange(i: int):
    return torch.arange(i)

arange(6)
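
One way to read the "for-loop" analogy, sketched under the assumption that you want one value per index: `arange(6)` holds every loop index at once, so a single tensor expression can stand in for the loop body. The names `looped`, `vectorized` and `is_small` are illustrative.

```python
import torch

def arange(i: int):
    return torch.arange(i)

# Loop version: visit every index and compute a value for it.
looped = [k * 2 + 1 for k in range(6)]

# Tensor version: one arithmetic expression plays the role of the loop body.
vectorized = arange(6) * 2 + 1

# Comparisons work the same way and yield a boolean tensor.
is_small = arange(6) < 3

print(looped, vectorized, is_small)
```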

### 1) where
https://numpy.org/doc/stable/reference/generated/numpy.where.html

In [None]:
def where_spec(q, a, b, out):
    for i in range(len(out)):
        out[i] = a[i] if q[i] else b[i]

def where(q: TT["i", bool], a: TT["i"], b: TT["i"]) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("where", where, where_spec))

### 2) ones
https://numpy.org/doc/stable/reference/generated/numpy.ones.html

In [None]:
def ones_spec(out):
    for i in range(len(out)):
        out[i] = 1

def ones(i: int) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("ones", ones, ones_spec, add_sizes=["i"]))

### 3) sum
https://numpy.org/doc/stable/reference/generated/numpy.sum.html

In [None]:
def sum_spec(a, out):
    out[0] = 0
    for i in range(len(a)):
        out[0] += a[i]

def sum(a: TT["i"]) -> TT[1]:
    raise NotImplementedError

run_test(make_test("sum", sum, sum_spec))

### 4) outer
https://numpy.org/doc/stable/reference/generated/numpy.outer.html

In [None]:
def outer_spec(a, b, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            out[i][j] = a[i] * b[j]

def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
    raise NotImplementedError

run_test(make_test("outer", outer, outer_spec))

### 5) diag
https://numpy.org/doc/stable/reference/generated/numpy.diag.html

In [None]:
def diag_spec(a, out):
    for i in range(len(a)):
        out[i] = a[i][i]
        
def diag(a: TT["i", "i"]) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("diag", diag, diag_spec))

### 6) eye
https://numpy.org/doc/stable/reference/generated/numpy.eye.html

In [None]:
def eye_spec(out):
    for i in range(len(out)):
        out[i][i] = 1
        
def eye(j: int) -> TT["j", "j"]:
    raise NotImplementedError

run_test(make_test("eye", eye, eye_spec, add_sizes=["j"]))

### 7) triu
https://numpy.org/doc/stable/reference/generated/numpy.triu.html

In [None]:
def triu_spec(out):
    for i in range(len(out)):
        for j in range(len(out)):
            if i <= j:
                out[i][j] = 1
            else:
                out[i][j] = 0
                
def triu(j: int) -> TT["j", "j"]:
    raise NotImplementedError

run_test(make_test("triu", triu, triu_spec, add_sizes=["j"]))

### 8) cumsum
https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html

In [None]:
def cumsum_spec(a, out):
    total = 0
    for i in range(len(out)):
        out[i] = total + a[i]
        total += a[i]

def cumsum(a: TT["i"]) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("cumsum", cumsum, cumsum_spec))

### 9) diff
https://numpy.org/doc/stable/reference/generated/numpy.diff.html

In [None]:
def diff_spec(a, out):
    out[0] = a[0]
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]

def diff(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("diff", diff, diff_spec, add_sizes=["i"]))
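
Note that this spec is not quite `numpy.diff`: the output has the same length as the input, and its first entry is `a[0]`. A small worked example, running the spec on plain Python lists (the input values are made up):

```python
def diff_spec(a, out):
    out[0] = a[0]
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]

a = [3, 5, 10, 4]
out = [0] * len(a)
diff_spec(a, out)
print(out)  # [3, 2, 5, -6]
```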

### 10) vstack
https://numpy.org/doc/stable/reference/generated/numpy.vstack.html

In [None]:
def vstack_spec(a, b, out):
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
    raise NotImplementedError

run_test(make_test("vstack", vstack, vstack_spec))

### 11) roll
https://numpy.org/doc/stable/reference/generated/numpy.roll.html

In [None]:
def roll_spec(a, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]
            
def roll(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("roll", roll, roll_spec, add_sizes=["i"]))
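
The spec always shifts by one position to the left, with the first element wrapping around to the end, which matches `numpy.roll(a, -1)`. A quick check on a plain list (input values are made up):

```python
def roll_spec(a, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]

a = [1, 2, 3, 4]
out = [0] * len(a)
roll_spec(a, out)
print(out)  # [2, 3, 4, 1]
```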

### 12) flip
https://numpy.org/doc/stable/reference/generated/numpy.flip.html

In [None]:
def flip_spec(a, out):
    for i in range(len(out)):
        out[i] = a[len(out) - i - 1]
        
def flip(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("flip", flip, flip_spec, add_sizes=["i"]))

### 13) compress
https://numpy.org/doc/stable/reference/generated/numpy.compress.html

In [None]:
def compress_spec(g, v, out):
    j = 0
    for i in range(len(g)):
        if g[i]:
            out[j] = v[i]
            j += 1
            
def compress(g: TT["i", bool], v: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("compress", compress, compress_spec, add_sizes=["i"]))
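
Unlike `numpy.compress`, the output here keeps the input's length: the selected values are packed at the front and the rest is left untouched. The sketch below assumes the output starts as all zeros, which is how I read the test harness; the input values are made up.

```python
def compress_spec(g, v, out):
    j = 0
    for i in range(len(g)):
        if g[i]:
            out[j] = v[i]
            j += 1

g = [True, False, True, False]
v = [10, 20, 30, 40]
out = [0] * len(v)          # assumption: the output is zero-initialized
compress_spec(g, v, out)
print(out)  # [10, 30, 0, 0]
```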

### 14) pad_to

https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html?highlight=pad#torch.nn.utils.rnn.pad_sequence

In [None]:
def pad_to_spec(a, out):
    for i in range(min(len(out), len(a))):
        out[i] = a[i]

def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
    raise NotImplementedError

run_test(make_test("pad_to", pad_to, pad_to_spec, add_sizes=["i", "j"]))
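
The spec covers two cases: when `j` is larger than `i` the tail stays at its initial value, and when `j` is smaller the input is truncated. A worked example on plain lists, assuming the output starts as zeros (the names `longer` and `shorter` are illustrative):

```python
def pad_to_spec(a, out):
    for i in range(min(len(out), len(a))):
        out[i] = a[i]

a = [1, 2, 3]

longer = [0] * 5            # j > i: pad with the initial zeros
pad_to_spec(a, longer)

shorter = [0] * 2           # j < i: keep only the first j entries
pad_to_spec(a, shorter)

print(longer)   # [1, 2, 3, 0, 0]
print(shorter)  # [1, 2]
```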

### 15) sequence_mask
https://www.tensorflow.org/api_docs/python/tf/sequence_mask

In [None]:
def sequence_mask_spec(values, length, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            if j < length[i]:
                out[i][j] = values[i][j]
            else:
                out[i][j] = 0

def constraint_set_length(d, sizes=None):
    d["length"] = d["length"] % d["values"].shape[1]
    return d
    
def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
    raise NotImplementedError

run_test(make_test("sequence_mask",
    sequence_mask, sequence_mask_spec, constraint=constraint_set_length
))
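
Note that, unlike `tf.sequence_mask`, this spec does not return a boolean mask: it returns `values` with everything at or beyond each row's length zeroed out. A small worked example on nested lists (input values are made up):

```python
def sequence_mask_spec(values, length, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            if j < length[i]:
                out[i][j] = values[i][j]
            else:
                out[i][j] = 0

values = [[1, 2, 3], [4, 5, 6]]
length = [1, 2]             # keep 1 entry of row 0 and 2 entries of row 1
out = [[0] * 3 for _ in range(2)]
sequence_mask_spec(values, length, out)
print(out)  # [[1, 0, 0], [4, 5, 0]]
```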

### 16) bincount
https://numpy.org/doc/stable/reference/generated/numpy.bincount.html

In [None]:
def bincount_spec(a, out):
    for i in range(len(a)):
        out[a[i]] += 1
        
def constraint_set_max(d, sizes=None):
    d["a"] = d["a"] % d["return"].shape[0]
    return d
        
def bincount(a: TT["i"], j: int) -> TT["j"]:
    raise NotImplementedError

run_test(make_test("bincount",
    bincount, bincount_spec, add_sizes=["j"], constraint=constraint_set_max
))

### 17) scatter_add
https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/add.html

In [None]:
def scatter_add_spec(values, link, out):
    for j in range(len(values)):
        out[link[j]] += values[j]

def constraint_set_max(d, sizes=None):
    d["link"] = d["link"] % d["return"].shape[0]
    return d

def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
    raise NotImplementedError


run_test(make_test("scatter_add",
    scatter_add, scatter_add_spec, add_sizes=["j"], constraint=constraint_set_max
))
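
To see what the spec computes: each `values[j]` is added into the output slot named by `link[j]`, so values sharing a link accumulate. A worked example on plain lists, assuming a zero-initialized output (input values are made up):

```python
def scatter_add_spec(values, link, out):
    for j in range(len(values)):
        out[link[j]] += values[j]

values = [5, 1, 7, 2]
link = [0, 2, 0, 1]         # 5 and 7 both land in slot 0
out = [0, 0, 0]             # assumption: the output is zero-initialized
scatter_add_spec(values, link, out)
print(out)  # [12, 2, 1]
```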

### 18) flatten

https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flatten.html

In [None]:
def flatten_spec(a, out):
    k = 0
    for i in range(len(a)):
        for j in range(len(a[0])):
            out[k] = a[i][j]
            k += 1

def flatten(a: TT["i", "j"], i: int, j: int) -> TT["i * j"]:
    raise NotImplementedError

run_test(make_test("flatten", flatten, flatten_spec, add_sizes=["i", "j"]))

### 19) linspace

https://numpy.org/doc/stable/reference/generated/numpy.linspace.html

In [None]:
def linspace_spec(i, j, out):
    for k in range(len(out)):
        out[k] = float(i + (j - i) * k / max(1, len(out) - 1))

def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
    raise NotImplementedError

run_test(make_test("linspace", linspace, linspace_spec, add_sizes=["n"]))
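
The `max(1, len(out) - 1)` in the spec guards against dividing by zero when `n` is 1, in which case the single output is just the start value. A sketch running the spec with plain floats in place of the one-element tensors (the names `five` and `single` are illustrative):

```python
def linspace_spec(i, j, out):
    for k in range(len(out)):
        out[k] = float(i + (j - i) * k / max(1, len(out) - 1))

five = [0.0] * 5
linspace_spec(2.0, 4.0, five)

single = [0.0]              # n = 1: only the start point is produced
linspace_spec(2.0, 4.0, single)

print(five)    # [2.0, 2.5, 3.0, 3.5, 4.0]
print(single)  # [2.0]
```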

### 20) heaviside

https://numpy.org/doc/stable/reference/generated/numpy.heaviside.html

In [None]:
def heaviside_spec(a, b, out):
    for k in range(len(out)):
        if a[k] == 0:
            out[k] = b[k]
        else:
            out[k] = int(a[k] > 0)

def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
    raise NotImplementedError

run_test(make_test("heaviside", heaviside, heaviside_spec))
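
In other words, `b` only matters where `a` is exactly zero; everywhere else the output is 0 for negative inputs and 1 for positive ones. A quick check on plain lists (input values are made up):

```python
def heaviside_spec(a, b, out):
    for k in range(len(out)):
        if a[k] == 0:
            out[k] = b[k]
        else:
            out[k] = int(a[k] > 0)

a = [-2, 0, 3]
b = [7, 8, 9]
out = [0] * len(a)
heaviside_spec(a, b, out)
print(out)  # [0, 8, 1]
```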

### 21) hstack

https://numpy.org/doc/stable/reference/generated/numpy.hstack.html

In [None]:
def hstack_spec(a, b, out):
    for i in range(len(out)):
        out[i][0] = a[i]
        out[i][1] = b[i]
            
def hstack(a: TT["i"], b: TT["i"]) -> TT["i", 2]:
    raise NotImplementedError

run_test(make_test("hstack", hstack, hstack_spec))

---

No more puppies from here on: the remaining puzzles have no automated tests. For now, check your answers against the examples shown in the linked docs.
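
Since these puzzles come without specs, one option is to compare your answers against the real PyTorch functions on small inputs. The sketch below only produces reference outputs; the built-ins themselves are still off limits inside a solution. The variable names are illustrative, and stacking the result of `torch.chunk` is an assumption made to match the 2-D return type in the `chunk` signature.

```python
import torch

a = torch.arange(6)                        # tensor([0, 1, 2, 3, 4, 5])
b = torch.tensor([1, 2, 3])

view_ref = a.view(2, 3)                    # [[0, 1, 2], [3, 4, 5]]
repeat_ref = b.repeat(2)                   # [1, 2, 3, 1, 2, 3]
interleave_ref = torch.repeat_interleave(b, 2)   # [1, 1, 2, 2, 3, 3]

# torch.chunk returns a tuple of tensors; stacking them gives one 2-D tensor.
chunk_ref = torch.stack(torch.chunk(a, 2))       # [[0, 1, 2], [3, 4, 5]]

# Each row of the result is the (row, column) index of a nonzero entry.
nonzero_ref = torch.nonzero(torch.tensor([[0, 1], [2, 0]]))   # [[0, 1], [1, 0]]

# With the default right=False, a value equal to a boundary falls in the
# bucket to its left.
bucket_ref = torch.bucketize(torch.tensor([1, 2, 5, 9]), torch.tensor([2, 6]))

print(view_ref, repeat_ref, interleave_ref, chunk_ref, nonzero_ref, bucket_ref)
```

Once you have a candidate, `torch.equal(your_answer, reference)` is a convenient way to compare the two.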

### 22) view (1d to 2d)

https://pytorch.org/docs/stable/generated/torch.Tensor.view.html

In [None]:
def view(a: TT["i * j"], i: int, j: int) -> TT["i", "j"]:
    raise NotImplementedError

### 23) repeat (1d)

https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html

In [None]:
def repeat(a: TT["i"], d: int) -> TT["d"]:
    raise NotImplementedError

### 24) repeat_interleave (1d)

https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html

In [None]:
def repeat_interleave(a: TT["i"], d: int) -> TT["d"]:
    raise NotImplementedError

### 25) chunk
https://pytorch.org/docs/stable/generated/torch.chunk.html

In [None]:
def chunk(a: TT["i"], c: int) -> TT["c", "i // c"]:
    raise NotImplementedError

### 26) nonzero
https://pytorch.org/docs/stable/generated/torch.nonzero.html

In [None]:
def nonzero(a: TT["i", "j"], i: int, j: int) -> TT["k", 2]:
    raise NotImplementedError

### 27) bucketize
https://pytorch.org/docs/stable/generated/torch.bucketize.html

In [None]:
def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
    raise NotImplementedError