# tensor prerequisites

## setup

In [67]:
import torch as t
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

%matplotlib inline

In [6]:
class CFG:
    """Size of each axis of the synthetic district dataset."""
    schools = 5    # schools in the district
    classes = 3    # classes per school
    students = 20  # students per class
    exams = 10     # exams taken by each student


# Seed torch's global RNG right before sampling so `data` is reproducible.
t.manual_seed(0xdeadbeef)

# Integer scores in [0, 100): `high` is exclusive, so 99 is the top possible score.
data = t.randint(
    low=0,
    high=100,
    size=(CFG.schools, CFG.classes, CFG.students, CFG.exams),
)

In [68]:
def draw(frames, figsize=(6, 6), display_inline=True, filename=None,
         fps=20, cmap='binary', vmin=None, vmax=None):
    """Animate a sequence of 2D frames with matplotlib.

    Args:
        frames: indexable sequence of 2D arrays/tensors accepted by `ax.imshow`;
            frame `i` is drawn on animation step `i`.
        figsize: figure size in inches.
        display_inline: if True, render the animation inline as HTML/JS.
        filename: if not None, also save the animation to this path.
        fps: playback rate. It also sets the inline frame interval (1000 / fps ms);
            the default of 20 reproduces the previous hard-coded 50 ms interval.
        cmap: colormap passed to `imshow`.
        vmin, vmax: optional fixed colour range passed to `imshow`. With the
            default (None) each frame is autoscaled on its own, so a frame whose
            cells are all equal is drawn in the colormap's lowest colour; pass
            `vmin=0, vmax=1` for 0/1 worlds to keep such frames correct.
    """
    # `display` is only an implicit builtin inside IPython; import it explicitly
    from IPython.display import HTML, display

    fig, ax = plt.subplots(figsize=figsize)
    fig.tight_layout()

    def render(i):
        # redraw from scratch so the previous frame's image does not linger
        ax.clear()
        ax.axis('off')
        ax.margins(0)
        ax.imshow(frames[i], cmap=cmap, vmin=vmin, vmax=vmax)

    ani = FuncAnimation(fig, render, frames=len(frames), interval=1000 / fps, repeat=False)
    # close this figure so the notebook does not also show a static copy of it
    plt.close(fig)
    if display_inline:
        display(HTML(ani.to_jshtml()))
    if filename is not None:
        ani.save(filename, fps=fps)

### answers

#### tensor manipulation

In [86]:
# find the highest score in the district
def test_district_highest_score(x):
    assert x.allclose(data.max())

# find the best score for each exam in each class in each school
def test_best_score_per_exam(x):
    assert x.allclose(data.max(dim=2).values)

# what class has the best average score in each school
def test_best_class(x):
    assert x.allclose(data.mean(dim=(2, 3), dtype=t.float32).argmax(dim=1))

# compute how many exams each student has failed (score < 50)
def test_failed_exams(x):
    assert x.allclose((data < 50).sum(dim=3))

# what score did the worst student (lowest GPA) got on the their best exam
def test_worst_student_best_exam(x):
    lowest_gpa = data.sum(dim=3).min()
    idx = data.sum(dim=3) == lowest_gpa
    idx.shape
    assert x.allclose(data[idx].max())

#### game of life

In [104]:
def _count_neighbors(world):
    """Reference solution: number of live neighbors of every cell.

    Args:
        world: 2D tensor of shape [height, width] holding 0. (dead) / 1. (alive).
            Non-floating tensors (int, bool) are converted to float32 first so they
            can be convolved with a floating kernel.
    Returns:
        Tensor of shape [height, width] with, for every cell, the sum of its 8
        surrounding cells. It has world's floating dtype (float32 for converted
        inputs). Cells outside the grid count as dead (zero padding), so the
        world does not wrap around.
    """
    if not world.is_floating_point():
        world = world.float()
    # 3x3 box filter without the centre: a cell is not its own neighbor.
    # conv2d needs input and weight of the same dtype/device, so build it to match.
    kernel = t.ones(1, 1, 3, 3, dtype=world.dtype, device=world.device)
    kernel[0, 0, 1, 1] = 0
    # conv2d expects [batch, channel, H, W]; add those two axes, then drop them
    return F.conv2d(world[None, None], kernel, padding=1)[0, 0]

def test_count_neighbors(f, world):
    assert f(world).allclose(_count_neighbors(world))

def _next_step(world):
    """Reference solution: advance `world` by one Game of Life generation.

    A cell is alive next generation when it has exactly 3 live neighbors, or
    when it is alive now (== 1) and has exactly 2 live neighbors. Returns a
    float32 tensor of 0./1. values.
    """
    counts = _count_neighbors(world)
    three_neighbors = counts == 3
    alive_with_two = (world == 1) & (counts == 2)
    return (three_neighbors | alive_with_two).float()

def test_next_step(f, world):
    assert f(world).allclose(_next_step(world))

## tensor manipulation

We will be working with a 4D tensor `data` of shape `[school, class, student, exam]` that holds the exam scores of every student in every class of every school in a district.

For example:
`data[0, 2, 5, 1]`

is the score that the 6th student in the 3rd class of the first school got on their second exam (indices are zero-based).

In [87]:
# find the highest score in the district
# <STUDENT JOB>
answer = t.max(data)  # reduce over every axis at once -> 0-d tensor

test_district_highest_score(answer)

In [92]:
# find the best score for each exam in each class in each school
# <STUDENT JOB>
answer = t.amax(data, dim=2)  # collapse the student axis -> [school, class, exam]

test_best_score_per_exam(answer)

In [93]:
# what class has the best average score in each school
# <STUDENT JOB>
class_means = data.mean(dim=(2, 3), dtype=t.float32)  # average over students and exams -> [school, class]
answer = class_means.argmax(dim=1)                     # index of the best class within each school

test_best_class(answer)

In [94]:
# compute how many exams each student has failed (score < 50)
# <STUDENT JOB>
answer = data.lt(50).sum(dim=-1)  # count failing scores along the exam axis -> [school, class, student]

test_failed_exams(answer)

In [95]:
# what score did the worst student (lowest GPA) get on their best exam
# <STUDENT JOB>
totals = data.sum(dim=3)    # every student sits the same exams, so the lowest total is the lowest GPA
lowest_gpa = totals.min()   # (stored as a total, not a mean)
idx = totals == lowest_gpa  # boolean mask over [school, class, student]; ties select every tied student
answer = data[idx].max()    # best exam score among the worst student(s)

test_worst_student_best_exam(answer)

## game of life

https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life

We can represent the world as a 2D tensor of shape `[height, width]`, where `0.` means dead and `1.` means alive.

In [57]:
# 8x8 world, all dead except for a 5-cell glider near the top
world = t.zeros(8, 8, dtype=t.float32)
world[1:4, 3] = 1.  # column 3, rows 1-3
world[2, 5] = 1.
world[3, 4] = 1.

world.shape

torch.Size([8, 8])

The Game of Life is a cellular automaton obeying the following rules:
- Any live cell with fewer than two live neighbors dies, as if by underpopulation.
- Any live cell with two or three live neighbors lives on to the next generation.
- Any live cell with more than three live neighbors dies, as if by overpopulation.
- Any dead cell with exactly three live neighbors becomes a live cell, as if by reproduction.

In [105]:
# find how many neighbors each cell has
def count_neighbors(world):
    # <STUDENT JOB>
    kernel = t.ones(1, 1, 3, 3)
    kernel[0, 0, 1, 1] = 0
    return F.conv2d(world[None, None], kernel, padding=1)[0, 0]

test_count_neighbors(count_neighbors, world)

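Note: the element-wise logical operators `&`, `|` and `~` bind more tightly than comparison operators such as `==`, so every comparison needs its own parentheses: write `(a == 3) | (b == 2)`, not `a == 3 | b == 2`.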

In [106]:
# update the world according to the rules of the game of life
def next_step(world):
    # <STUDENT JOB>
    neighbors = count_neighbors(world)
    next_generation = (neighbors == 3) | ((world == 1) & (neighbors == 2))
    return next_generation.float()

test_next_step(next_step, world)

In [107]:
def play(world, epochs=10):
    """Run `epochs` Game of Life generations starting from `world` and animate them.

    The initial world is the first frame, followed by one frame per generation
    computed with `next_step`.
    """
    history = [world]
    for _ in range(epochs):
        history.append(next_step(history[-1]).clone())
    draw(history)

play(world)