# [0.4] - Build Your Own Backpropagation Framework (exercises)

> **ARENA [Streamlit Page](https://arena-chapter0-fundamentals.streamlit.app/04_[0.4]_Backprop)**
>
> **Colab: [exercises](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter0_fundamentals/exercises/part4_backprop/0.4_Backprop_exercises.ipynb?t=20250910) | [solutions](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter0_fundamentals/exercises/part4_backprop/0.4_Backprop_solutions.ipynb?t=20250910)**

Please send any problems / bugs on the `#errata` channel in the [Slack group](https://join.slack.com/t/arena-uk/shared_invite/zt-3afdmdhye-Mdb3Sv~ss_V_mEaXEbkABA), and ask any questions on the dedicated channels for this chapter of material.

You can collapse each section so only the headers are visible, by clicking the arrow symbol on the left hand side of the markdown header cells.

Links to all other chapters: [(0) Fundamentals](https://arena-chapter0-fundamentals.streamlit.app/), [(1) Transformer Interpretability](https://arena-chapter1-transformer-interp.streamlit.app/), [(2) RL](https://arena-chapter2-rl.streamlit.app/).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/headers/header-04.png" width="350">

# Introduction

Today you're going to build your very own system that can run the backpropagation algorithm in essentially the same way as PyTorch does. By the end of the day, you'll be able to train a multi-layer perceptron neural network, using your own backprop system!

The main differences between the full PyTorch and our version are:

* We will focus on CPU only, as all the ideas are the same on GPU.
* We will use NumPy arrays internally instead of ATen, the C++ array type used by PyTorch. Backpropagation works independently of the array type.
* A real `torch.Tensor` has about 700 fields and methods. We will only implement a subset that are particularly instructional and/or necessary to train the MLP.

Note - for today, I'd lean a lot more towards being willing to read the solutions, and even move on from some of them if you don't fully understand them (especially in the first half of section 3). The low-level messy implementation details for today are much less important than the high-level conceptual takeaways.

For a lecture on the material today, which provides some high-level understanding before you dive into the material, watch the video below:

<iframe width="540" height="304" src="https://www.youtube.com/embed/-24lS-kk5I0" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>

## Content & Learning Objectives

### 1️⃣ Introduction to backprop

This takes you through what a **computational graph** is, and the basics of how gradients can be backpropagated through such a graph. You'll also implement the backwards versions of some basic functions: if we have tensors `output = func(input)`, then the backward function of `func` can calculate the grad of `input` as a function of the grad of `output`.

> ##### Learning Objectives
>
> * Understand what a computational graph is, and how it can be used to calculate gradients.
> * Start to implement backwards versions of some basic functions.

### 2️⃣ Autograd

This section goes into more detail on the backpropagation methodology. In order to find the `grad` of each tensor in a computational graph, we first have to perform a **topological sort** of the tensors in the graph, so that each time we try to calculate `tensor.grad`, we've already computed all the other gradients which are used in this calculation. We end this section by writing a `backprop` function, which works just like the `tensor.backward()` method you're already used to in PyTorch.

> ##### Learning Objectives
>
> * Perform a topological sort of a computational graph (and understand why this is important).
> * Implement a the `backprop` function, to calculate and store gradients for all tensors in a computational graph.

### 3️⃣ Training on MNIST from scratch

In this section, we build your own equivalents of `torch.nn` features like `nn.Parameter`, `nn.Module`, and `nn.Linear`. We can then use these to build our own neural network to classify MINST data.

This completes the chain which starts at basic numpy arrays, and ends with us being able to build essentially any neural network architecture we want!

> ##### Learning Objectives
>
> * Implement more forward and backward functions, including for indexing, summing, and matrix multiplication
> * Learn how to build higher-level abstractions like parameters and modules on top of individual functions and tensors
> * Complete the process of building up a neural network from scratch and training it via gradient descent.

### 4️⃣ Bonus

A few bonus exercises are suggested, for pushing your understanding of backpropagation further.

## Setup code

In [1]:
import os
import sys
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules

chapter = "chapter0_fundamentals"
repo = "ARENA_3.0"
branch = "main"

# Install dependencies
try:
    import jaxtyping
except:
    %pip install einops jaxtyping

# Get root directory, handling 3 different cases: (1) Colab, (2) notebook not in ARENA repo, (3) notebook in ARENA repo
root = (
    "/content"
    if IN_COLAB
    else "/root"
    if repo not in os.getcwd()
    else str(next(p for p in Path.cwd().parents if p.name == repo))
)

if Path(root).exists() and not Path(f"{root}/{chapter}").exists():
    if not IN_COLAB:
        !sudo apt-get install unzip
        %pip install jupyter ipython --upgrade

    if not os.path.exists(f"{root}/{chapter}"):
        !wget -P {root} https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/{branch}.zip
        !unzip {root}/{branch}.zip '{repo}-{branch}/{chapter}/exercises/*' -d {root}
        !mv {root}/{repo}-{branch}/{chapter} {root}/{chapter}
        !rm {root}/{branch}.zip
        !rmdir {root}/{repo}-{branch}


if f"{root}/{chapter}/exercises" not in sys.path:
    sys.path.append(f"{root}/{chapter}/exercises")

os.chdir(f"{root}/{chapter}/exercises")

In [2]:
import os
import re
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator

import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

Arr = np.ndarray
grad_tracking_enabled = True

# Make sure exercises are in the path
chapter = "chapter0_fundamentals"
section = "part4_backprop"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))


import part4_backprop.tests as tests
from part4_backprop.utils import get_mnist, visualize
from plotly_utils import line

<details>
<summary>Help - I get a NumPy-related error</summary>

This is an annoying colab-related issue which I haven't been able to find a satisfying fix for. If you restart runtime (but don't delete runtime), and run just the imports cell above again (but not the `%pip install` cell), the problem should go away.
</details>

# 1️⃣ Introduction to backprop

> ##### Learning Objectives
>
> * Understand what a computational graph is, and how it can be used to calculate gradients.
> * Start to implement backwards versions of some basic functions.

## Reading

* [Calculus on Computational Graphs: Backpropagation (Chris Olah)](https://colah.github.io/posts/2015-08-Backprop/)

## Computing Gradients with Backpropagation

This section will briefly review the backpropagation algorithm, but focus mainly on the concrete implementation in software.

To train a neural network, we want to know how the loss would change if we slightly adjust one of the learnable parameters.

One obvious and straightforward way to do this would be just to add a small value  to the parameter, and run the forward pass again. This is called finite differences, and the main issue is we need to run a forward pass for every single parameter that we want to adjust. This method is infeasible for large networks, but it's important to know as a way of sanity checking other methods.

A second obvious way is to write out the function for the entire network, and then symbolically take the gradient to obtain a symbolic expression for the gradient. This also works and is another thing to check against, but the expression gets quite complicated.

Suppose that you have some **computational graph**, and you want to determine the derivative of the some scalar loss L with respect to NumPy arrays a, b, and c:

<img src="https://raw.githubusercontent.com/callummcdougall/Fundamentals/main/images/abc_de_L.png" width="400">

This graph corresponds to the following Python:

```python
d = a * b
e = b + c
L = d + e
```

The goal of our system is that users can write ordinary looking Python code like this and have all the book-keeping needed to perform backpropagation happen behind the scenes. To do this, we're going to wrap each array and each function in special objects from our library that do the usual thing plus build up this graph structure that we need.

### Backward Functions

We've drawn our computation graph from left to right and the arrows pointing to the right, so that in the forward pass, boxes to the right depend on boxes to the left. In the backwards pass, the opposite is true: the gradient of boxes on the left depends on the gradient of boxes on the right.

If we want to compute the derivative of $L$ wrt all other variables (as was described in the reading), we should traverse the graph from right to left. Each time we encounter an instance of function application, we can use the chain rule from calculus to proceed one step further to the left. For example, if we have $d = a \times b$, then:

$$
\frac{dL}{da} = \frac{dL}{dd}\times \frac{dd}{da} = \frac{dL}{dd}\times b
$$

Suppose we are working from right to left, trying to calculate $\frac{dL}{da}$. If we already know the values of the variables $a$, $b$ and $d$, as well as the value of $\frac{dL}{dd}$, then we can use the following function to find $\frac{dL}{da}$:

$$
F\left(a, b, d, \frac{dL}{dd}\right) = \frac{dL}{dd}\times b
$$

and we can do something similar when trying to calculate $\frac{dL}{db}$.

In other words, we can take the **"forward function"** $(a, b) \to a \cdot b$, and for each of its parameters, we can define an associated **"backwards function"** which tells us how to compute the gradient wrt this argument using only known quantities as inputs.

Ignoring issues of unbroadcasting (which we'll cover later), we could write the backward with respect to the first argument as:

```python
def multiply_back(grad_out, out, a, b):
    '''
    Inputs:
        grad_out = dL/d(out)
        out = a * b

    Returns:
        dL/da
    '''
    return grad_out * b
```

where `grad_out` is the gradient of the loss of the function with respect to the output (i.e. $\frac{dL}{dd}$), `out` is the output of the function (i.e. $d$), and `a` and `b` are our inputs.

### Topological Ordering

When we're actually doing backprop, how do we guarantee that we'll always know the value of our backwards functions' inputs? For instance, in the example above we couldn't have computed $\frac{dL}{da}$ without first knowing $\frac{dL}{dd}$.

The answer is that we sort all our nodes using an algorithm called [topological sorting](https://en.wikipedia.org/wiki/Topological_sorting), and then do our computations in this order. After each computation, we store the gradients in our nodes for use in subsequent calculations.

When described in terms of the diagram above, topological sort can be thought of as an ordering of nodes from right to left. Crucially, this sorting has the following property: if there is a directed path in the computational graph going from node `x` to node `y`, then `x` must follow `y` in the sorting.

There are many ways of proving that a cycle-free directed graph contains a topological ordering. You can try and prove this for yourself, or click on the expander below to reveal the outline of a simple proof.

<details>
<summary>Click to reveal proof</summary>

We can prove by induction on the number of nodes $N$.
    
If $N=1$, the problem is trivial.

If $N>1$, then pick any node, and follow the arrows until you reach a node with no directed arrows going out of it. Such a node must exist, or else you would be following the arrows forever, and you'd eventually return to a node you previously visited, but this would be a cycle, which is a contradiction. Once you've found this "root node", you can put it first in your topological ordering, then remove it from the graph and apply the topological sort on the subgraph created by removing this node. By induction, your topological sorting algorithm on this smaller graph should return a valid ordering. If you append the root node to the start of this ordering, you have a topological ordering for the whole graph.
</details>

A quick note on some potentially confusing terminology. We will refer to the "end node" as the **root node**, and the "starting nodes" as **leaf nodes**. For instance, in the diagram at the top of the section, the left nodes `a`, `b` and `c` are the leaf nodes, and `L` is the root node. This might seem odd given it makes the leaf nodes come before the root nodes, but the reason is as follows: *when we're doing the backpropagation algorithm, we start at `L` and work our way back through the graph*. So, by our notation, we start at the root node and work our way out to the leaves.

Another important piece of terminology here is **parent node**. This means the same thing as it does in most other contexts - the parents of node `x` are all the nodes `y` with connections `y -> x` (so in the diagram, `L`'s parents are `d` and `e`).

<details>
<summary>Question - can you think of a reason it might be important for a node to store a list of all of its parent nodes?</summary>

During backprop, we're moving from right to left in the diagram. If a node doesn't store its parent, then there will be no way to get access to that parent node during backprop, so we can't propagate gradients to it.
</details>

The very first node in our topological sort will be $L$, the root node.

### Backpropagation

After all this setup, the backpropagation mechanism becomes pretty straightforward. We sort the nodes topologically, then we iterate over them and call each backward function exactly once in order to accumulate the gradients at each node.

It's important that the grads be accumulated instead of overwritten in a case like value $b$ which has two outgoing edges, since $\frac{dL}{db}$ will then be the sum of two terms. Since addition is commutative it doesn't matter whether we `backward()` the Mul or the Add that depend on $b$ first.

During backpropagation, for each forward function in our computational graph we need to find the partial derivative of the output with respect to each of its inputs. Each partial is then multiplied by the gradient of the loss with respect to the forward functions output (`grad_out`) to find the gradient of the loss with respect to each input. We'll handle these calculations using backward functions.

## Backward function of log

First, we'll write the backward function for `x -> out = log(x)`. This should be a function which, when fed the values `x, out, grad_out = dL/d(out)` returns the value of `dL/dx` just from this particular computational path.

<img src="https://raw.githubusercontent.com/callummcdougall/Fundamentals/main/images/x_log_out.png" width="400">

Note - it might seem strange at first why we need `x` and `out` to be inputs, `out` can be calculated directly from `x`. The answer is that sometimes it is computationally cheaper to express the derivative in terms of `out` than in terms of `x`.

<details>
<summary>Question - can you think of an example function where it would be computationally cheaper to use 'out' than to use 'x'?</summary>

The most obvious answer is the exponential function, `out = e ^ x`. Here, the gradient `d(out)/dx` is equal to `out`. We'll see this when we implement a backward version of `torch.exp` later today.
</details>

### Exercise - implement `log_back`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 5-10 minutes on this exercise.
> ```

You should fill in this function below. Don't worry about division by zero or other edge cases - the goal here is just to see how the pieces of the system fit together.

*This should just be a short, one-line function.*

In [3]:
def log_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backwards function for f(x) = log(x)

    grad_out: Gradient of some loss wrt out
    out: the output of np.log(x).
    x: the input of np.log.

    Return: gradient of the given loss wrt x
    """
    # grad_out = dL/d(out)
    # out = np.log(x)
    # dL/dx = dL/d(out) * d(out)/dx = grad_out * 1/x

    return grad_out / x


tests.test_log_back(log_back)

All tests in `test_log_back` passed!


<details>
<summary>Help - I'm not sure what the output of this backward function for log should be.</summary>

By the chain rule, we have:

$$
\frac{dL}{dx} = \frac{dL}{d(\text{out})} \cdot \frac{d(\text{out})}{dx} = \frac{dL}{d(\text{out})} \cdot \frac{d(\log{x})}{dx} = \frac{dL}{d(\text{out})} \cdot \frac{1}{x}
$$

---

(Note - technically, $\frac{d(\text{out})}{dx}$ is a tensor containing the derivatives of each element of $\text{out}$ with respect to each element of $x$, and we should matrix multiply when we use the chain rule. However, since $\text{out} = \log x$ is an elementwise function of $x$, our application of the chain rule here will also be an elementwise multiplication: $\frac{dL}{dx_{ij}} = \frac{dL}{d(\text{out}_{ij})} \cdot \frac{d(\text{out}_{ij})}{dx_{ij}}$. When we get to things like matrix multiplication later, we'll have to be a bit more careful!)
</details>

<details>
<summary>Help - I get <code>ImportError: numpy.core.multiarray failed to import</code></summary>

This is an annoying colab-related error which I haven't been able to find a satisfying fix for. The setup code at the top of this notebook should have installed a version of numpy which works, although you'll have to click "Restart Runtime" (from the Runtime menu) to make this work. Make sure you don't re-run the cell with `pip install`s.

If this still doesn't work, please flag it in the Slack channel errata.
</details>


<details><summary>Solution</summary>

```python
def log_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backwards function for f(x) = log(x)

    grad_out: Gradient of some loss wrt out
    out: the output of np.log(x).
    x: the input of np.log.

    Return: gradient of the given loss wrt x
    """
    return grad_out / x
```
</details>

## Backward functions of two tensors

Now we'll implement backward functions for multiple tensors. From this point onwards, we'll be working with vector-valued derivatives, i.e. terms like $\frac{\partial x}{\partial y}$ where $y$ (and possibly $x$) are tensors. To recap what this notation means - these terms are tensors where each value is the scalar derivative of one of the elements of $x$ wrt one of the elements of $y$. For example:

- If $x$ is a scalar and $y$ is a tensor with shape `(3, 4)`, then $\frac{\partial x}{\partial y}$ is a tensor with shape `(3, 4)` where the `[i, j]`-th element is $\frac{\partial x}{\partial y[i, j]}$, i.e. the derivative of $x$ wrt a particular element of $y$.
- If $x$ is a length-5 vector and $y$ is a tensor with shape `(3, 4)` then $\frac{\partial x}{\partial y}$ is a tensor with shape `(5, 3, 4)` where the `[i, j, k]`-th element is $\frac{\partial x[i]}{\partial y[j, k]}$.

### Broadcasting

Before we go through the next exercises, we'll need to address one important topic in tensor operations - **broadcasting**. We've discussed it in earlier exercises, but we're reviewing it here because it's very important for backprop. If you're comfortable with the topic, feel free to jump forwards to the section "Why do we need broadcasting for backprop?".

Both NumPy and PyTorch have the same rules for broadcasting. When two tensors are involved in an elementwise operation, NumPy/PyTorch tries to broadcast them (i.e. copying them along dimensions) so that they both have the same shape. The rules of broadcasting are as follows:

1. You can prepend dummy dimensions (of size 1) to the start of a tensor until both have the same number of dimensions
2. After this point, if some dimension has size 1 in one of the tensors, it can be repeated until it matches the size of the corresponding dimension in the other tensor

To give a simple example - suppose we have a 2D batch of data, of shape `data.shape = (N, k)` (i.e. we have `N` separate datapoints, each being a vector of length `k`). Suppose we want to add a vector `vec` of length `k` to each datapoint. This is a valid operation, because when we try and add these two objects together:

1. `vec` gets prepended with a dummy dimension so it has shape `(1, k)` and both are 2D
2. `vec` gets repeated along the first dimension so it has shape `(N, k)`, matching the shape of `data`

Then, our output has shape `(N, k)`, and elements `output[i, j] = data[i, j] + vec[j]`.

Broadcasting can be a very easy place to make mistakes, because it's easy to lose track of the exact shape of your tensors involved. As a warm-up exercise, below are some examples of broadcasting. Can you figure out which are valid, and which will raise errors?


```python
x = np.ones((3, 1, 5))
y = np.ones((1, 4, 5))

z = x + y
```

<details>
<summary>Answer</summary>

This is valid, because the 0th dimension of `y` and the 1st dimension of `x` can both be copied so that `x` and `y` have the same shape: `(3, 4, 5)`. The resulting array `z` will also have shape `(3, 4, 5)`.

This example illustrates an important point - it's not always the case that one of the tensors is strictly smaller and the other is strictly bigger. Sometimes, both tensors will get expanded. 

</details>

```python
x = np.ones((8, 2, 6))
y = np.ones((8, 2))

z = x + y
```

<details>
<summary>Answer</summary>

This is not valid. We first need to expand `y` by appending a dimension to the front, and the last two dimensions of `x` are `(2, 6)`, which won't broadcast with `y`'s `(8, 2)`.
</details>

```python
x = np.ones((8, 2, 6))
y = np.ones((2, 6))

z = x + y
```

<details>
<summary>Answer</summary>

This is valid. Once NumPy expands `y` by appending a single dimension to the front, it can then be broadcast with `x`.
</details>

```python
x = np.ones((10, 20, 30))
y = np.ones((20, 1))

z = x + y
```

<details>
<summary>Answer</summary>

This is valid. Once NumPy expands `y` by appending a single dimension to the front, it can then be broadcast with `x` (this will involve copying along the first and last dimensions).
</details>

```python
x = np.ones((4, 1))
y = np.ones((4,))

z = x + y
```

<details>
<summary>Answer</summary>

This is valid. Numpy will expand the second one to `(1, 4)` then broadcast them both to `(4, 4)`.

Although this won't raise an error, it's very possible that this isn't what the person adding these two arrays intended. A common source of mistakes is when you add 2 tensors thinking they're the same shape, but one actually has a dummy dimension you forgot about. Sadly this is something you'll just have to be vigilant for (e.g. adding assert statements where necessary, or making sure you aren't combining too many different tensor operations in a single line), because PyTorch doesn't have built-in ways of statically checking your tensor shapes.

</details>

### Why do we need broadcasting for backprop?

Often, operations like `out = f(x, y)` involve an implicit broadcasting step. For example, if `out = x + y` with `x.shape = (4,)` and `y.shape = (3, 4)`, then really our operation has 2 steps:

1. Broadcast `x` to a broadcasted version `x_b` which has the shape of `y`, i.e. copy it along the zeroth dimension to have shape `(3, 4)`,
2. Define `out = x_b + y`, which is an operation that doesn't involve broadcasting.

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/broadcast-last-1.png" width="600">

How can we go from the gradient `dL/d(x_b)` (which might be easy to compute given it involves no broadcasting) to the gradient of `dL/dx`? The answer is that **we take `dL/d(x_b)` and sum it over the dimensions along which `x` was broadcasted to get `dL/dx`.**

For the mathematically inclined this can be proved fairly easily (we provide a sketch of the proof below). But let's focus on the intuition here. When we copy `x` along axes to create `x_b`, we're essentially creating multiple pathways for each element of `x` to affect the final loss: one path for each time the element of `x` was copied. So when we change some element of `x`, this causes a first-order change in the loss from all of these different pathways, and we have to sum these changes to get the total first-order change (i.e. the derivative). For example:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/broadcast-last-2.png" width="540">

> ##### Summary
>
> If we know $\frac{\partial L}{\partial (\mathbf{out})}$, and want to know $\frac{\partial L}{\partial \mathbf{x}}$ (where $x$ was broadcasted to produce $\text{out}$) then there are two steps:
>
> 1. Compute $\frac{\partial L}{\partial \mathbf{x_b}}$ in the standard way, i.e. using one of your backward functions (assuming no broadcasting happened).
> 2. ***Unbroadcast*** $\frac{\partial L}{\partial \mathbf{x_b}}$, by summing it over the dimensions along which $\mathbf{x}$ was broadcasted.


<details>
<summary>Mathematical derivation (sketch)</summary>

We start with:

$$
\frac{\partial L}{\partial \mathbf{x}} = \sum_{i_1, i_2, ...} \frac{\partial L}{\partial x_b[i_1, i_2, ...]} \cdot \frac{\partial x_b[i_1, i_2, ...]}{\partial \mathbf{x}}
$$

using the chain rule. Next, we can use the fact that $\frac{\partial x_b[i_1, i_2, ...]}{\partial x[j_1, j_2, ...]}$ equals 0 everywhere except the indices where $x$ was broadcasted to produce $x_b$, where it equals 1. Therefore, this sum over indices $i_1, i_2, ...$ is equivalent to just summing $\frac{\partial L}{\partial x_b[i_1, i_2, ...]}$ over the repeated elements of $x$.

</details>

We used the term "unbroadcast" because the way that our tensor's shape changes will be the reverse of how it changed during broadcasting. If `x` was broadcasted from `(4,) -> (3, 4)`, then unbroadcasting will have to take a tensor of shape `(3, 4)` and sum over it to return a tensor of shape `(4,)`. Similarly, if `x` was broadcasted from `(1, 4) -> (3, 4)`, then we need to sum over the zeroth dimension, but leave it as a 2D tensor with a leading dummy dimension of size 1.

### Exercise - implement `unbroadcast`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 10-25 minutes on this exercise.
> This function can be quite fiddly - if you're stuck, read and understand the solution code instead.
> ```

The `unbroadcast` function below should take any value like $\frac{\partial L}{\partial x_b}$ and return $\frac{\partial L}{\partial x}$ (where the arguments `original` and `broadcasted` represent $x$ and $x_b$ respectively). In the same way that broadcasting `original` to the shape of `broadcasted` can be described as a 2-step process:

1. Extend dummy dimensions (size 1) of `original` to match the last dimensions of `broadcasted`,
2. Prepend dimensions to `original` until it has the same ndims as `broadcasted`,

the unbroadcasting process is the same process in reverse:

1. Sum over prepended dimensions of `broadcasted` until it has the same ndims as `original`,
2. Now they have the same ndims, sum over dimensions of `broadcasted` where `original` had size 1.

For example, if `original.shape = (3, 1)` and `broadcasted.shape = (1, 2, 3, 4)` then the broadcasting steps look like `(3, 1) -> (3, 4) -> (1, 2, 3, 4)` (copying along those dimensions at each step) and the unbroadcasting steps look like `(1, 2, 3, 4) -> (3, 4) -> (3, 1)` (summing over those dimensions at each step).

In [4]:
def unbroadcast(broadcasted: Arr, original: Arr) -> Arr:
    """
    Sum 'broadcasted' until it has the shape of 'original'.

    broadcasted: An array that was formerly of the same shape of 'original' and was expanded by
        broadcasting rules.
    """
    # YOUR CODE HERE: sum over `broadcasted` until it has the shape of `original`
    added_dims = len(broadcasted.shape) - len(original.shape)
    added_dims = tuple(range(added_dims))
    broadcasted = broadcasted.sum(axis=added_dims)

    for i in range(len(broadcasted.shape)):
        if broadcasted.shape[i] > 1 and original.shape[i] == 1:
            broadcasted = broadcasted.sum(axis=i, keepdims=True)

    assert broadcasted.shape == original.shape
    return broadcasted


tests.test_unbroadcast(unbroadcast)

All tests in `test_unbroadcast` passed!


<details><summary>Solution</summary>

```python
def unbroadcast(broadcasted: Arr, original: Arr) -> Arr:
    """
    Sum 'broadcasted' until it has the shape of 'original'.

    broadcasted: An array that was formerly of the same shape of 'original' and was expanded by
        broadcasting rules.
    """
    # Step 1: sum and remove prepended dims, so both arrays have same number of dims
    n_dims_to_sum = len(broadcasted.shape) - len(original.shape)
    broadcasted = broadcasted.sum(axis=tuple(range(n_dims_to_sum)))

    # Step 2: sum over dims which were originally 1 (but don't remove them)
    dims_to_sum = tuple(
        [i for i, (o, b) in enumerate(zip(original.shape, broadcasted.shape)) if o == 1 and b > 1]
    )
    broadcasted = broadcasted.sum(axis=dims_to_sum, keepdims=True)

    assert broadcasted.shape == original.shape
    return broadcasted
```
</details>

### Backward Function for Elementwise Multiply

Functions that are differentiable with respect to more than one input tensor are straightforward given that we already know how to handle broadcasting.

- We're going to have two backwards functions, one for each input argument.
- If the input arguments were broadcasted together to create a larger output, the incoming `grad_out` will be of the larger common broadcasted shape and we need to make use of `unbroadcast` from earlier to match the shape to the appropriate input argument.
- We'll want our backward function to work when one of the inputs is a float (as opposed to a tensor). We won't need to calculate the grad_in with respect to floats, so we only need to consider when y is a float for `multiply_back0` and when x is a float for `multiply_back1`.

### Exercise - implement both `multiply_back` functions

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 10-15 minutes on these exercises.
> ```

Below, you should implement both `multiply_back0` and `multiply_back1`.

You might be wondering why we need two different functions, rather than just having a single function to serve both purposes. This will become more important later on, once we deal with functions with more than one argument, which is not symmetric in its arguments. For instance, the derivative of $x / y$ wrt $x$ is not the same as the expression you get after differentiating this wrt $y$ then swapping the labels around.

The first part of each function has been provided for you (this makes sure that both inputs are arrays, since we want to support multiplication by floats or scalars).

In [5]:
def multiply_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr | float) -> Arr:
    """Backwards function for x * y wrt argument 0 aka x."""
    if not isinstance(y, Arr):
        y = np.array(y)

    return unbroadcast(grad_out * y, x) 

def multiply_back1(grad_out: Arr, out: Arr, x: Arr | float, y: Arr) -> Arr:
    """Backwards function for x * y wrt argument 1 aka y."""
    if not isinstance(x, Arr):
        x = np.array(x)

    return unbroadcast(grad_out * x, y)


tests.test_multiply_back(multiply_back0, multiply_back1)
tests.test_multiply_back_float(multiply_back0, multiply_back1)

All tests in `test_multiply_back` passed!
All tests in `test_multiply_back_float` passed!


<details>
<summary>Hint</summary>

Using the chain rule for scalar functions, we can find that:

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial (xy)} \cdot \frac{\partial (xy)}{\partial x} = \frac{\partial L}{\partial (xy)} \cdot y
$$

An equivalent result\* holds for vector-valued inputs:

$$
\frac{\partial L}{\partial \mathbf{x}_b} = \frac{\partial L}{\partial (\mathbf{x}_b \cdot \mathbf{y}_b)} \cdot \mathbf{y}_b
$$

where $\cdot$ is the elementwise multiplication operator, and $\mathbf{x}_b, \mathbf{y}_b$ are the broadcasted versions of the input tensors. This allows us to compute $\frac{\partial L}{\partial \mathbf{x}_b}$ based on the inputs to our `multiply_back0` function. Can you use `unbroadcast` to finish the calculation?

\*You can derive this result for yourself using the chain rule. Note that application of the chain rule here would require you to sum over the elements of the intermediate tensor $\mathbf{x}_b \cdot \mathbf{y}_b$, but we're working with elementwise operations so $\partial (\mathbf{x}_b \cdot \mathbf{y}_b)[i_1, i_2, ...] / \partial \mathbf{x}_b$ will have all elements zero except for one and the sum falls out, leaving us with the result above.

</details>

<details>
<summary>Solution (and explanation)</summary>

Based on the discussion in hint 1, we can conclude that the value $\partial L / \partial \mathbf{x}_b$ is equal to `y_b * grad_out`. Note that this is just equivalent to `y * grad_out` because `y` will be broadcasted to the shape of `y_b` when we multiply it with `grad_out` (since we know `grad_out` has the same shape as the broadcasted `x * y` tensor).

Now, we need to go from this broadcasted derivative to the shape of `x`, and we do this by applying `unbroadcast`, summing `y * grad_out` over dimensions so that it's the same shape as `x` - in other words, we return `unbroadcast(y * grad_out, x)`.

The output of `multiply_back1` is the same as `multiply_back0`, except the roles of `x` and `y` are reversed.

```python
def multiply_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr | float) -> Arr:
    """Backwards function for x * y wrt argument 0 aka x."""
    if not isinstance(y, Arr):
        y = np.array(y)

    return unbroadcast(y * grad_out, x)


def multiply_back1(grad_out: Arr, out: Arr, x: Arr | float, y: Arr) -> Arr:
    """Backwards function for x * y wrt argument 1 aka y."""
    if not isinstance(x, Arr):
        x = np.array(x)

    return unbroadcast(x * grad_out, y)
```

<!-- Derivation: first we compute the gradient of the broadcasted tensor, using the chain rule:

$$
\frac{\partial L}{\partial x_b[i_1, i_2, ...]} = \frac{\partial L}{\partial (x_b \cdot y_b)[i_1, i_2, ...]} \cdot \frac{\partial (x_b \cdot y_b)[i_1, i_2, ...]}{\partial x_b[i_1, i_2, ...]}
$$

where we've avoided the sum over indices of $x_b \cdot y_b$ because this is an elementwise operation, and so those gradients must be zero everywhere except where all $i_k = j_k$. We can further write this as:

$$
\frac{\partial L}{\partial x_b[i_1, i_2, ...]} = \text{grad\_out}[i_1, i_2, ...] \cdot y[i_1, i_2, ...]
$$

which gives us the result for broadcasted tensors! Note that although this tensor is `grad_out * y_b`, we can compute it as `grad_out * y` because we know that `y` is broadcastable to the shape of `grad_out` (since this is also the same as the shape of `out`).

Once we have this result `grad_out * y`, we just use `unbroadcast` to sum over the dimensions `x` was broadcasted along.  -->

</details>

Now we'll use our backward functions to do backpropagation manually, for the following computational graph:

<img src="https://raw.githubusercontent.com/callummcdougall/Fundamentals/main/images/abcdefg.png" width=550>

### Exercise - implement `forward_and_back`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵🔵⚪
> 
> You should spend up to 15-20 minutes on these exercises.
> This function is very useful for getting a sense for what backprop looks like in practice.
> ```

Below, you should implement the `forward_and_back` function. This is an opportunity for you to practice using the backward functions you've written so far, and should hopefully give you a better sense of how the full backprop function will eventually work.

Note - you might be wondering what `grad_out` should be the first time you call a backward function. We've so far assumed that the final node `L` in the computational graph is a scalar, and all gradients (i.e. the `grad_out` input to our backward functions and the output of our backward functions) are gradients of `L`. In this case, it would make sense to set `grad_out = [1]` for the first backward function call, since we're computing $\frac{\partial L}{\partial L} = 1$. But what about cases like the one above, when our final node `g` might not be a scalar?

The answer is that we effectively add a final scalar node into our graph, defined as `L = (g * v).sum()` for some tensor `v` of the same shape as `g`. This is called the **directional derivative**. We then compute all tensor gradients using `L` as our final node. For example, this means that in our first gradient calculation (where `g` is the output and `f` is the input) we'll use `grad_out = dL/dg = v`. You don't really need to worry about this, because most of the time (including in the exercise below) we'll just be using the default behaviour where `v` is a tensor of 1s - this is equivalent to `L = g.sum()`.

In [6]:
def forward_and_back(a: Arr, b: Arr, c: Arr) -> tuple[Arr, Arr, Arr]:
    """
    Calculates the output of the computational graph above (g), then backpropogates the gradients
    and returns dg/da, dg/db, and dg/dc.
    """
    d = a * b
    e = np.log(c)
    f = d * e
    g = np.log(f)
    final_grad_out = np.ones_like(g)

    # YOUR CODE HERE - use your backward functions to compute the gradients of g wrt a, b, and c

    """
    g = log(f)
    dg/df = 1/f
    
    f = d * e
    dg_dd = dg_df * df_dd
    """
    dg_df = log_back(final_grad_out, g, f)

    dg_dd = multiply_back0(dg_df, f, d, e)
    dg_de = multiply_back1(dg_df, f, d, e)
    
    dg_da = multiply_back0(dg_dd, d, a, b)
    dg_db = multiply_back1(dg_dd, d, a, b)

    dg_dc = log_back(dg_de, e, c)

    return (dg_da, dg_db, dg_dc)


tests.test_forward_and_back(forward_and_back)

All tests in `test_forward_and_back` passed!


<details><summary>Solution</summary>

```python
def forward_and_back(a: Arr, b: Arr, c: Arr) -> tuple[Arr, Arr, Arr]:
    """
    Calculates the output of the computational graph above (g), then backpropogates the gradients
    and returns dg/da, dg/db, and dg/dc.
    """
    d = a * b
    e = np.log(c)
    f = d * e
    g = np.log(f)
    final_grad_out = np.ones_like(g)

    dg_df = log_back(grad_out=final_grad_out, out=g, x=f)
    dg_dd = multiply_back0(dg_df, f, d, e)
    dg_de = multiply_back1(dg_df, f, d, e)
    dg_da = multiply_back0(dg_dd, d, a, b)
    dg_db = multiply_back1(dg_dd, d, a, b)
    dg_dc = log_back(dg_de, e, c)

    return (dg_da, dg_db, dg_dc)
```
</details>

In the next section, you'll build up to full automation of this backpropagation process, in a way that's similar to PyTorch's `autograd`.

# 2️⃣ Autograd

> ##### Learning Objectives
>
> * Perform a topological sort of a computational graph (and understand why this is important).
> * Implement a the `backprop` function, to calculate and store gradients for all tensors in a computational graph.

Now, rather than figuring out which backward functions to call, in what order, and what their inputs should be, we'll write code that takes care of that for us. We'll implement this with a few major components:

- `Tensor`, which is a wrapper around numpy arrays which is equivalent to PyTorch's `Tensor` class
- `Recipe`, which tracks the extra information needed to run backpropagation (mainly how this tensor was created from other tensors)
- `wrap_forward_fn`, which takes a numpy function mapping arrays to arrays (e.g. `np.log`) and returns a new function that maps tensors to tensors (while also creating the recipe for the new tensor)

## Wrapping Arrays (Tensor)

We're going to wrap each array with a wrapper object from our library which we'll call `Tensor` because it's going to behave similarly to a `torch.Tensor`.

Each Tensor that is created by one of our forward functions will have a `Recipe`, which tracks the extra information need to run backpropagation.

`wrap_forward_fn` will take a forward function and return a new forward function that does the same thing while recording the info we need to do backprop in the `Recipe`.

## Recipe

Let's start by taking a look at `Recipe`.

`@dataclass` is a handy class decorator that sets up an `__init__` function for the class that takes the provided attributes as arguments and sets them as you'd expect.

The class `Recipe` is designed to track the forward functions in our computational graph, so that gradients can be calculated during backprop. Each tensor created by a forward function has its own `Recipe`. We're naming it this because it is a set of instructions that tell us which ingredients went into making our tensor: what the function was, and what tensors were used as input to the function to produce this one as output.

In [7]:
@dataclass(frozen=True)
class Recipe:
    """Extra information necessary to run backpropagation. You don't need to modify this."""

    func: Callable
    "The 'inner' NumPy function that does the actual forward computation."
    "Note, we call it 'inner' to distinguish it from the wrapper we'll create for it later on."

    args: tuple
    "The input arguments passed to func."
    "For instance, if func=np.sum then args would be a length-1 tuple with the tensor to be summed."

    kwargs: dict[str, Any]
    "Keyword arguments passed to func."
    "For instance, if func was np.sum then kwargs might contain 'dim' and 'keepdims'."

    parents: dict[int, "Tensor"]
    "Map from positional argument index to the Tensor at that position."
    "For passing gradients back along the computational graph."

Note that `args` just stores the values of the underlying arrays, but `parents` stores the actual tensors. This is because they serve two different purposes: `args` is required for computing the value of gradients during backpropagation, and `parents` is required to infer the structure of the computational graph (i.e. which tensors were used to produce which other tensors).

Here are some examples, to build intuition for what the four fields of `Recipe` are, and why we need all four of them to fully describe a tensor in our graph and how it was created. Make sure you understand each of these examples before moving on, because it'll really help you progress quickly through the following exercises.

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/recipe-v2.png" width="800">

## Registering backwards functions

The `Recipe` takes care of tracking the forward functions in our computational graph, but we still need a way to find the backward function corresponding to a given forward function when we do backprop (or possibly the set of backward functions, if the forward function takes more than one argument).

### Exercise - implement `BackwardFuncLookup`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 10-15 minutes on these exercises.
> These exercises should be very short, once you understand what is being asked.
> ```

We will define a class `BackwardFuncLookup` in order to find the backward function for a given forward function. The implementation details are left up to you - all that matters is that you pass the test code in the cell below. Reading this test code should explain how the `BackwardFuncLookup` class needs to be used - for any given forward function e.g. `np.log`, we need to be able to add a set of backward functions for each of its positional arguments.

In [8]:
class BackwardFuncLookup:
    def __init__(self) -> None:
        self.data = {}

    def add_back_func(self, forward_fn: Callable, arg_position: int, back_fn: Callable) -> None:
        self.data[(forward_fn, arg_position)] = back_fn

    def get_back_func(self, forward_fn: Callable, arg_position: int) -> Callable:
        return self.data[(forward_fn, arg_position)]


BACK_FUNCS = BackwardFuncLookup()

BACK_FUNCS.add_back_func(np.log, 0, log_back)
BACK_FUNCS.add_back_func(np.multiply, 0, multiply_back0)
BACK_FUNCS.add_back_func(np.multiply, 1, multiply_back1)

assert BACK_FUNCS.get_back_func(np.log, 0) == log_back
assert BACK_FUNCS.get_back_func(np.multiply, 0) == multiply_back0
assert BACK_FUNCS.get_back_func(np.multiply, 1) == multiply_back1

print("Tests passed - BackwardFuncLookup class is working as expected!")

Tests passed - BackwardFuncLookup class is working as expected!


<details>
<summary>Help - I'm stuck on this implementation</summary>

You can define a dict like `self.back_funcs` in the `__init__` method. When you add / retrieve a function, you can use the tuple `(forward_fn, arg_position)` as a key, and the backward function as the value.

</details>


<details><summary>Solution</summary>

```python
class BackwardFuncLookup:
    def __init__(self) -> None:
        self.back_funcs = {}  # each entry is a tuple of (forward_fn, arg_position) -> back_fn

    def add_back_func(self, forward_fn: Callable, arg_position: int, back_fn: Callable) -> None:
        self.back_funcs[(forward_fn, arg_position)] = back_fn

    def get_back_func(self, forward_fn: Callable, arg_position: int) -> Callable:
        return self.back_funcs[(forward_fn, arg_position)]
```

</details>

## Tensors

Our Tensor object has these fields:

- An `array` field of type `np.ndarray`. These are the actual tensor values.
- A `requires_grad` field of type `bool`. This determines whether we need to compute gradients for this tensor (note this doesn't necessarily mean we need to store them, see below).
- A `grad` field of the same size and type as the value. This is where gradients are stored.
- A `recipe` field, as we've already seen. A tensor has a recipe if and only if it was created by some operation on other tensors.

### `requires_grad` and `is_leaf`

The meaning of `requires_grad` is that when doing operations using this tensor, the recipe will be stored and it and any descendents will be included in the computational graph. Note that `requires_grad` does not necessarily mean that we will save the accumulated gradients to this tensor's `.grad` parameter when doing backprop - for example we require gradients to propagate through the hidden activations of a neural network to get back to grads for our model weights, but we don't need to actually store the gradients of the hidden activations.

We use `is_leaf` to differentiate between these cases (see the method `Tensor.is_leaf` defined below) - a leaf tensor is one that represents the end of a backprop path, either because it doesn't require gradients or it has no nodes further back in the computational graph which require gradients. So our backprop algorithm will always terminate at a leaf node, then store that leaf node's gradient as `.grad` only if `requires_grad` is true.

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/is_leaf.png" width="900">

You can investigate this by running the following code:

```python
layer = torch.nn.Linear(3, 4)
input = torch.ones(3)
output = layer(input)

print(layer.weight.is_leaf)       # -> True
print(layer.weight.requires_grad) # -> True

print(output.is_leaf)             # -> False
print(output.requires_grad)       # -> True

print(input.is_leaf)              # -> True
print(input.requires_grad)        # -> False
```

When creating tensors, we can set `requires_grad` explicitly (e.g. it's false by default for most tensors, but is true by default if that tensor is wrapped in `torch.nn.Parameter` - we'll create our own version of this later). When creating a tensor from another tensor or tensors, `requires_grad` is true if and only if all of the following 3 conditions hold:

1. Global grad tracking is enabled. In this notebook we've represented this with the global variable `grad_tracking_enabled`, but in PyTorch this is done with `torch.set_grad_enabled(False)`. This is useful because when you're looking at a model in inference mode, gradient tracking can waste memory and it's useful to disable it (we'll do this a lot next week, when we study transformer interpretability).
2. At least one of the input tensors requires grad (since this is equivalent to "there are other tensors further upstream which we need to get gradients for").
3. The function is differentiable (if not, obviously we can't compute gradients).

Now, we're giving you the full `Tensor` class. Most of these methods are currently undefined, and you'll go on to define them in later exercises (so you won't need to write any code this class). For now, just pay attention to the docstring & `__init__` methods.

In [9]:
Arr = np.ndarray


class Tensor:
    """
    A drop-in replacement for torch.Tensor supporting a subset of features.
    """

    array: Arr
    "The underlying array. Can be shared between multiple Tensors."
    requires_grad: bool
    "If True, calling functions or methods on this tensor will track relevant data for backprop."
    grad: "Tensor | None"
    "Backpropagation will accumulate gradients into this field."
    recipe: "Recipe | None"
    "Extra information necessary to run backpropagation."

    def __init__(self, array: Arr | list, requires_grad=False):
        self.array = array if isinstance(array, Arr) else np.array(array)
        if self.array.dtype == np.float64:
            self.array = self.array.astype(np.float32)
        self.requires_grad = requires_grad
        self.grad = None
        self.recipe = None
        "If not None, this tensor's array was created as recipe.func(*recipe.args, **recipe.kwargs)."

    def __neg__(self) -> "Tensor":
        return negative(self)

    def __add__(self, other) -> "Tensor":
        return add(self, other)

    def __radd__(self, other) -> "Tensor":
        return add(other, self)

    def __sub__(self, other) -> "Tensor":
        return subtract(self, other)

    def __rsub__(self, other) -> "Tensor":
        return subtract(other, self)

    def __mul__(self, other) -> "Tensor":
        return multiply(self, other)

    def __rmul__(self, other):
        return multiply(other, self)

    def __truediv__(self, other):
        return true_divide(self, other)

    def __rtruediv__(self, other):
        return true_divide(other, self)

    def __matmul__(self, other):
        return matmul(self, other)

    def __rmatmul__(self, other):
        return matmul(other, self)

    def __eq__(self, other):
        return eq(self, other)

    def __repr__(self) -> str:
        return f"Tensor({repr(self.array)}, requires_grad={self.requires_grad})"

    def __len__(self) -> int:
        if self.array.ndim == 0:
            raise TypeError
        return self.array.shape[0]

    def __hash__(self) -> int:
        return id(self)

    def __getitem__(self, index) -> "Tensor":
        print(self, index)
        return getitem(self, index)

    def add_(self, other: "Tensor", alpha: float = 1.0) -> "Tensor":
        add_(self, other, alpha=alpha)
        return self

    def sub_(self, other: "Tensor", alpha: float = 1.0) -> "Tensor":
        sub_(self, other, alpha=alpha)
        return self

    def __iadd__(self, other: "Tensor") -> "Tensor":
        self.add_(other)
        return self

    def __isub__(self, other: "Tensor") -> "Tensor":
        self.sub_(other)
        return self

    @property
    def T(self) -> "Tensor":
        return permute(self, axes=(-1, -2))

    def item(self):
        return self.array.item()

    def sum(self, dim=None, keepdim=False) -> "Tensor":
        return sum(self, dim=dim, keepdim=keepdim)

    def log(self) -> "Tensor":
        return log(self)

    def exp(self) -> "Tensor":
        return exp(self)

    def reshape(self, new_shape) -> "Tensor":
        return reshape(self, new_shape)

    def permute(self, dims) -> "Tensor":
        return permute(self, dims)

    def maximum(self, other) -> "Tensor":
        return maximum(self, other)

    def relu(self) -> "Tensor":
        return relu(self)

    def argmax(self, dim=None, keepdim=False) -> "Tensor":
        return argmax(self, dim=dim, keepdim=keepdim)

    def uniform_(self, low: float, high: float) -> "Tensor":
        self.array[:] = np.random.uniform(low, high, self.array.shape)
        return self

    def backward(self, end_grad: "Arr | Tensor | None" = None) -> None:
        if isinstance(end_grad, Arr):
            end_grad = Tensor(end_grad)
        return backprop(self, end_grad)

    def size(self, dim: int | None = None):
        if dim is None:
            return self.shape
        return self.shape[dim]

    @property
    def shape(self):
        return self.array.shape

    @property
    def ndim(self):
        return self.array.ndim

    @property
    def is_leaf(self):
        """Same as https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html"""
        if self.requires_grad and self.recipe and self.recipe.parents:
            return False
        return True

    def __bool__(self):
        if np.array(self.shape).prod() != 1:
            raise RuntimeError("bool value of Tensor with more than one value is ambiguous")
        return bool(self.item())


def empty(*shape: int) -> Tensor:
    """Like torch.empty."""
    return Tensor(np.empty(shape))


def zeros(*shape: int) -> Tensor:
    """Like torch.zeros."""
    return Tensor(np.zeros(shape))


def arange(start: int, end: int, step=1) -> Tensor:
    """Like torch.arange(start, end)."""
    return Tensor(np.arange(start, end, step=step))


def tensor(array: Arr, requires_grad=False) -> Tensor:
    """Like torch.tensor."""
    return Tensor(array, requires_grad=requires_grad)

## Forward Pass: Building the Computational Graph

Let's start with a simple case: our `log` function. `log_forward` is a wrapper, which should implement the functionality of `np.log` but work with tensors rather than arrays.

### Exercise - implement `log_forward`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 15-20 minutes on this exercise.
> ```

Your `log` function should be a wrapper around `np.log`, which takes and returns a `Tensor` object rather than numpy arrays. You can refer to the first of the five diagrams at the start of the "Recipe" section if you're stuck.

Some more hints / tips:

- As a reminder, `requires_grad` is true if both global gradient tracking is enabled (i.e. `grad_tracking_enabled` is true) and at least one of the inputs has `requires_grad` true.
- You should also set the recipe for the new tensor, if `requires_grad` is true (if not then you can just set the recipe to None).

Later we'll write code to wrap numpy functions in a generic and reusable way, but for now we just want to get this working for `np.log`.

In [10]:
def log_forward(x: Tensor) -> Tensor:
    """Performs np.log on a Tensor object."""
    arr = np.log(x.array)
    req_grad = (grad_tracking_enabled and x.requires_grad)
    ret = Tensor(arr, req_grad)

    if req_grad:
        ret.recipe = Recipe(np.log, (x.array,), {}, {0: x})
    return ret


log = log_forward
tests.test_log(Tensor, log_forward)
tests.test_log_no_grad(Tensor, log_forward)
a = Tensor([1], requires_grad=True)
grad_tracking_enabled = False
b = log_forward(a)
grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"

All tests in `test_log` passed!
All tests in `test_log_no_grad` passed!


<details><summary>Solution</summary>

```python
def log_forward(x: Tensor) -> Tensor:
    """Performs np.log on a Tensor object."""
    # Get the function output (as a numpy array)
    array = np.log(x.array)

    # Find whether the tensor requires grad
    requires_grad = grad_tracking_enabled and x.requires_grad

    # Create the tensor
    out = Tensor(array, requires_grad)

    # Set the recipe (if we need it)
    if requires_grad:
        out.recipe = Recipe(func=np.log, args=(x.array,), kwargs={}, parents={0: x})

    return out
```
</details>

Now let's do the same for multiply, to see how to handle functions with multiple arguments.

### Exercise - implement `multiply_forward`

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 15-20 minutes on this exercise.
> ```

There are a few differences between this and log:

- The actual function to be called is different
- We need more than one argument in `args` and `parents`, when defining `Recipe`
- `requires_grad` should be true if `grad_tracking_enabled=True`, and ANY of the input tensors require grad
- One of the inputs may be an int, so you'll need to deal with this case before calculating `out`

If you're confused, you can scroll up to the diagram at the top of the page (which tells you how to construct the recipe for functions like multiply or add when they are both arrays, or when one is an array and the other is a scalar).

In [11]:
def multiply_forward(a: Tensor | int, b: Tensor | int) -> Tensor:
    """Performs np.multiply on a Tensor object."""
    assert isinstance(a, Tensor) or isinstance(b, Tensor)

    # Get all function arguments as non-tensors (i.e. either ints or arrays)
    arg_a = a.array if isinstance(a, Tensor) else a
    arg_b = b.array if isinstance(b, Tensor) else b

    array = arg_a * arg_b
    requires_grad = (a.requires_grad if isinstance(a, Tensor) else False)
    requires_grad = requires_grad or (b.requires_grad if isinstance(b, Tensor) else False)
    requires_grad = requires_grad and grad_tracking_enabled
    
    ret = Tensor(array, requires_grad)

    if not requires_grad:
        return ret
    
    parents = {}

    if isinstance(a, Tensor):
        parents[0] = a
    
    if isinstance(b, Tensor):
        parents[1] = b

    ret.recipe = Recipe(
        np.multiply,
        (arg_a, arg_b),
        {},
        parents
    )

    return ret


multiply = multiply_forward
tests.test_multiply(Tensor, multiply_forward)
tests.test_multiply_no_grad(Tensor, multiply_forward)
tests.test_multiply_float(Tensor, multiply_forward)
a = Tensor([2], requires_grad=True)
b = Tensor([3], requires_grad=True)
grad_tracking_enabled = False
b = multiply_forward(a, b)
grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"

All tests in `test_multiply` passed!
All tests in `test_multiply_no_grad` passed!
All tests in `test_multiply_float` passed!


<details>
<summary>Help - I get <code>AttributeError: 'int' object has no attribute 'array'</code>.</summary>

Remember that your multiply function should also accept integers. You need to separately deal with the cases where `a` and `b` are integers or Tensors.
</details>

<details>
<summary>Help - I get <code>AssertionError: assert len(c.recipe.parents) == 1 and c.recipe.parents[0] is a</code> in the "test_multiply_float" test.</summary>

This is probably because you've stored the inputs to `multiply` as integers when one of the is an integer. Remember, `parents` should just be a list of the **Tensors** that were inputs to `multiply`, so you shouldn't add ints.
</details>


<details><summary>Solution</summary>

```python
def multiply_forward(a: Tensor | int, b: Tensor | int) -> Tensor:
    """Performs np.multiply on a Tensor object."""
    assert isinstance(a, Tensor) or isinstance(b, Tensor)

    # Get all function arguments as non-tensors (i.e. either ints or arrays)
    arg_a = a.array if isinstance(a, Tensor) else a
    arg_b = b.array if isinstance(b, Tensor) else b

    # Calculate the output (which is a numpy array)
    out_arr = arg_a * arg_b
    assert isinstance(out_arr, np.ndarray)

    # Find whether the tensor requires grad (need to check if ANY of the inputs do)
    requires_grad = grad_tracking_enabled and any(
        [isinstance(x, Tensor) and x.requires_grad for x in (a, b)]
    )

    # Create the output tensor from the underlying data and the requires_grad flag
    out = Tensor(out_arr, requires_grad)

    # If requires_grad, then create a recipe
    if requires_grad:
        parents = {idx: arr for idx, arr in enumerate([a, b]) if isinstance(arr, Tensor)}
        out.recipe = Recipe(np.multiply, (arg_a, arg_b), {}, parents)

    return out
```
</details>

## Forward Pass - Generic Version

All our forward functions are going to look extremely similar to `log_forward` and `multiply_forward`.
Implement the higher order function `wrap_forward_fn` that takes a `Arr -> Arr` function and returns a `Tensor -> Tensor` function. In other words, `wrap_forward_fn(np.multiply)` should evaluate to a callable that does the same thing as your `multiply_forward` (and same for `np.log`).

### Exercise - implement `wrap_forward_fn`

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵🔵🔵🔵⚪
> 
> You should spend up to 20-25 minutes on this exercise.
> This exercise is probably the 2nd most conceptually important today, after the backprop implementation at the end of the section.
> ```

If you're stuck, you can start with the same structure as the wrapped multiply function above (i.e. just copy and paste the code from solutions and use this as a stand in for `tensor_func` below, then modify it).

In [12]:
def wrap_forward_fn(numpy_func: Callable, is_differentiable=True) -> Callable:
    """
    Args:
        numpy_func:
            takes any number of positional arguments, some of which may be NumPy arrays, and any
            number of keyword arguments which we aren't allowing to be NumPy arrays at present. It
            returns a single NumPy array.

        is_differentiable:
            if True, numpy_func is differentiable with respect to some input argument, so we may
            need to track information in a Recipe. If False, we definitely don't need to track
            information.

    Returns:
        tensor_func
            It has the same signature as numpy_func, except it operates on Tensors instead of Arr.
    """

    def tensor_func(*args: Any, **kwargs: Any) -> Tensor:
        # Get all function arguments as non-tensors (i.e. either ints or arrays)
        arg_arrays = tuple([(a.array if isinstance(a, Tensor) else a) for a in args])

        # YOUR CODE HERE - create output array & make it a tensor with requires_grad (& recipe)
        
        # print(arg_arrays, kwargs)
        out_array = numpy_func(*arg_arrays, **kwargs)
        # print(out_array)
        if not grad_tracking_enabled or not is_differentiable:
            return Tensor(out_array, False)
        
        requires_grad = False
        parents = {}
        for i, arg in enumerate(args):
            if not isinstance(arg, Tensor):
                continue

            parents[i] = arg

            if arg.requires_grad:
                requires_grad = True

        if not requires_grad:
            return Tensor(out_array, False)
        
        out = Tensor(out_array, True)

        out.recipe = Recipe(
            numpy_func,
            arg_arrays,
            kwargs,
            parents
        )

        return out

    return tensor_func


def _sum(x: Arr, dim=None, keepdim=False) -> Arr:
    # need to be careful with sum, because kwargs have different names in torch and numpy
    return np.sum(x, axis=dim, keepdims=keepdim)


log = wrap_forward_fn(np.log)
multiply = wrap_forward_fn(np.multiply)
eq = wrap_forward_fn(np.equal, is_differentiable=False)
sum = wrap_forward_fn(_sum)

tests.test_log(Tensor, log)
tests.test_log_no_grad(Tensor, log)
tests.test_multiply(Tensor, multiply)
tests.test_multiply_no_grad(Tensor, multiply)
tests.test_multiply_float(Tensor, multiply)
tests.test_eq(Tensor, eq)
tests.test_sum(Tensor)

All tests in `test_log` passed!
All tests in `test_log_no_grad` passed!
All tests in `test_multiply` passed!
All tests in `test_multiply_no_grad` passed!
All tests in `test_multiply_float` passed!
All tests in `test_eq` passed!
All tests in `test_sum` passed!


<details>
<summary>Help - I'm getting <code>NameError: name 'getitem' is not defined</code>.</summary>

This is probably because you're calling `numpy_func` on the args themselves. Recall that `args` will be a list of `Tensor` objects, and that you should call `numpy_func` on the underlying arrays.
</details>

<details>
<summary>Help - I'm getting an AssertionError on <code>assert c.requires_grad == True</code> (or something similar).</summary>

This is probably because you're not defining `requires_grad` correctly. Remember that the output of a forward function should have `requires_grad = True` if and only if all of the following hold:

* Grad tracking is enabled
* The function is differentiable
* **Any** of the inputs are tensors with `requires_grad = True`
</details>

<details>
<summary>Help - my function passes all tests up to <code>test_sum</code>, but then fails here.</summary>

`test_sum`, unlike the previous tests, wraps a function that uses keyword arguments. So if you're failing here, it's probably because you didn't use `kwargs` correctly.

`kwargs` should be used in two ways: once when actually calling the `numpy_func`, and once when defining the `Recipe` object for the output tensor.
</details>


<details><summary>Solution</summary>

```python
def wrap_forward_fn(numpy_func: Callable, is_differentiable=True) -> Callable:
    """
    Args:
        numpy_func:
            takes any number of positional arguments, some of which may be NumPy arrays, and any
            number of keyword arguments which we aren't allowing to be NumPy arrays at present. It
            returns a single NumPy array.

        is_differentiable:
            if True, numpy_func is differentiable with respect to some input argument, so we may
            need to track information in a Recipe. If False, we definitely don't need to track
            information.

    Returns:
        tensor_func
            It has the same signature as numpy_func, except it operates on Tensors instead of Arr.
    """

    def tensor_func(*args: Any, **kwargs: Any) -> Tensor:
        # Get all function arguments as non-tensors (i.e. either ints or arrays)
        arg_arrays = tuple([(a.array if isinstance(a, Tensor) else a) for a in args])

        # Calculate the output (which is a numpy array)
        out_arr = numpy_func(*arg_arrays, **kwargs)

        # Find whether the tensor requires grad (need to check if ANY of the inputs do)
        requires_grad = (
            grad_tracking_enabled
            and is_differentiable
            and any([(isinstance(a, Tensor) and a.requires_grad) for a in args])
        )

        # Create the output tensor from the underlying data and the requires_grad flag
        out = Tensor(out_arr, requires_grad)

        # If requires_grad, then create a recipe
        if requires_grad:
            parents = {idx: a for idx, a in enumerate(args) if isinstance(a, Tensor)}
            out.recipe = Recipe(numpy_func, arg_arrays, kwargs, parents)

        return out

    return tensor_func
```
</details>

Note - none of these functions involve keyword args, so the tests won't detect if you're handling kwargs incorrectly (or even failing to use them at all). If your code fails in later exercises, you might want to come back here and check that you're using the kwargs correctly. Alternatively, once you pass the tests, you can compare your code to the solutions and see how they handle kwargs.

## Backpropagation

Now all the pieces are in place to implement backpropagation. We need to loop over our nodes from right to left (i.e. starting with the tensors computed last and moving backwards chronologically). At each node, we:

- Call the backward function to transform the grad wrt output to the grad wrt input.
- If the node is a leaf, write the grad to the grad field.
- Otherwise, accumulate the grad into temporary storage.

### Topological Sort

As part of backprop, we need to sort the nodes of our graph so we can traverse the graph in the appropriate order.

### Exercise - implement `topological_sort`

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵⚪⚪⚪⚪
> 
> You should spend up to 20-25 minutes on this exercise.
> Note, it's completely fine to skip this problem if you're not very interested in it.
> It's more of a fun LeetCode-style challenge, and writing a solution for it isn't crucial for understanding today's content.
> ```

Write a general function `topological_sort` that return a list of node's children in topological order (beginning with the furthest descendants, ending with the starting node) using [depth-first search](https://en.wikipedia.org/wiki/Topological_sorting).

We've given you a `Node` class, with a `children` attribute, and a `get_children` function. You shouldn't change any of these, and your `topological_sort` function should use `get_children` to access a node's children rather than calling `node.children` directly. In subsequent exercises, we'll replace the `Node` class with the `Tensor` class (and using a different `get_children` function), so this will ensure your code still works for this new case.

If you're stuck, try looking at the pseudocode from some of [these examples](https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm).

In [13]:
class Node:
    def __init__(self, *children):
        self.children = list(children)


def get_children(node: Node) -> list[Node]:
    return node.children


def topological_sort(node: Node, get_children: Callable) -> list[Node]:
    """
    Return a list of node's descendants in reverse topological order from future
    to past (i.e. `node` should be last).

    Should raise an error if the graph with `node` as root is not in fact acyclic.
    """
    result: list[
        Node
    ] = []  # stores the list of nodes to be returned (in reverse topological order)
    perm: set[Node] = set()  # same as `result`, but as a set (faster to check for membership)
    temp: set[Node] = set()  # keeps track of previously visited nodes (to detect cyclicity)

    def visit(cur: Node):
        """
        Recursive function which visits all the children of the current node,
        and appends them all to `result` in the order they were found.
        """
        if cur in perm:
            return
        if cur in temp:
            raise ValueError("Not a DAG!")
        temp.add(cur)

        for next in get_children(cur):
            visit(next)

        result.append(cur)
        perm.add(cur)
        temp.remove(cur)

    visit(node)
    return result


tests.test_topological_sort_linked_list(topological_sort)
tests.test_topological_sort_branching(topological_sort)
tests.test_topological_sort_rejoining(topological_sort)
tests.test_topological_sort_cyclic(topological_sort)

All tests in `test_topological_sort_linked_list` passed!
All tests in `test_topological_sort_branching` passed!
All tests in `test_topological_sort_rejoining` passed!
All tests in `test_topological_sort_cyclic` passed!


<details>
<summary>Help - my function is hanging without returning any values.</summary>

This is probably because it's going around in cycles when fed a cyclic graph. You should add a way of raising an error if your function detects that the graph isn't cyclic. One way to do this is to create a set `temp`, which stores the nodes you've visited on a particular excursion into the graph, then you can raise an error if you come across an already visited node.
</details>

<details>
<summary>Help - I'm completely stuck on how to implement this, and would like the template for some code.</summary>

Here is the template for a depth-first search implementation:

```python
def topological_sort(node: Node, get_children: Callable) -> list[Node]:
    
    result: list[Node] = [] # stores the list of nodes to be returned (in reverse topological order)
    perm: set[Node] = set() # same as `result`, but as a set (faster to check for membership)
    temp: set[Node] = set() # keeps track of previously visited nodes (to detect cyclicity)

    def visit(cur: Node):
        '''
        Recursive function which visits all the children of the current node, and appends them all
        to `result` in the order they were found.
        '''
        pass

    visit(node)
    return result
```
</details>


<details><summary>Solution</summary>

```python
def topological_sort(node: Node, get_children: Callable) -> list[Node]:
    """
    Return a list of node's descendants in reverse topological order from future
    to past (i.e. `node` should be last).

    Should raise an error if the graph with `node` as root is not in fact acyclic.
    """
    result: list[
        Node
    ] = []  # stores the list of nodes to be returned (in reverse topological order)
    perm: set[Node] = set()  # same as `result`, but as a set (faster to check for membership)
    temp: set[Node] = set()  # keeps track of previously visited nodes (to detect cyclicity)

    def visit(cur: Node):
        """
        Recursive function which visits all the children of the current node,
        and appends them all to `result` in the order they were found.
        """
        if cur in perm:
            return
        if cur in temp:
            raise ValueError("Not a DAG!")
        temp.add(cur)

        for next in get_children(cur):
            visit(next)

        result.append(cur)
        perm.add(cur)
        temp.remove(cur)

    visit(node)
    return result
```
</details>

Now, we've given you the function `sorted_computational_graph`. This calls `topological_sort` and returns the result in reverse order (because we want to start with the root node). The "get children" function we're using here is "return all tensors in the recipe for this tensor".

<img src="https://github.com/callummcdougall/Fundamentals/blob/main/images/abcdefg.png?raw=true" width=500>

In [14]:
def sorted_computational_graph(tensor: Tensor) -> list[Tensor]:
    """
    For a given tensor, return a list of Tensors that make up the nodes of the given Tensor's
    computational graph, in reverse topological order (i.e. `tensor` should be first).
    """

    def get_parents(tensor: Tensor) -> list[Tensor]:
        if tensor.recipe is None:
            return []
        return list(tensor.recipe.parents.values())

    return topological_sort(tensor, get_parents)[::-1]


a = Tensor([1], requires_grad=True)
b = Tensor([2], requires_grad=True)
c = Tensor([3], requires_grad=True)
d = a * b
e = c.log()
f = d * e
g = f.log()
name_lookup = {a: "a", b: "b", c: "c", d: "d", e: "e", f: "f", g: "g"}

print([name_lookup[t] for t in sorted_computational_graph(g)])

['g', 'f', 'e', 'c', 'd', 'b', 'a']


Compare your output with the computational graph. You should never be printing `x` before `y` if there is an edge `x --> ... --> y` (this should result in approximately reverse alphabetical order).

### The `backward` method

Now we're really ready for backprop!

Recall that in the implementation of the class `Tensor`, we had:

```python
class Tensor:
    ...
    def backward(self, end_grad: "Arr | Tensor | None" = None) -> None:
        if isinstance(end_grad, Arr):
            end_grad = Tensor(end_grad)
        return backprop(self, end_grad)
```

In other words, for a tensor `out`, calling `out.backward()` is equivalent to `backprop(out)`.

Recall that in the last section, we said that calling `backward` on a scalar tensor is equivalent to backpropagating on the weighted sum of all the elements of the tensor, i.e. `L = (tensor * v).sum()`. By default `v` is a tensor of 1s of the same shape as the tensor you're calling `backward` from, meaning we're just backpropagating on `L.sum()`. Here, the `end_grad` argument you pass to `backward` gives you the option to override this default behaviour, in other words if it's supplied you should use it as the first input to your backward function instead of a tensor of 1s. The use case for this is pretty niche (used for things like **influence functions**), but it's still useful to understand!

### Exercise - implement `backprop`

> ```yaml
> Difficulty: 🔴🔴🔴🔴🔴
> Importance: 🔵🔵🔵🔵🔵
> 
> You should spend up to 30-45 minutes on this exercise.
> 
> This exercise is the most conceptually important today, and probably the hardest. We've provided several dropdowns to help you.
> ```

Now, we get to the actual backprop function! Some code is provided below, which you should complete.

If you want a challenge, you can try and implement it straight away, with out any help. However, because this is quite a challenging exercise, you can also use the dropdowns below. The first one gives you a sketch of the backpropagation algorithm, the second gives you a diagram which provides a bit more detail, and the third gives you the annotations for the function (so you just have to fill in the code). You are recommended to start by trying to implement it without help, but use the dropdowns (in order) if this is too difficult.

We've also provided a few dropdowns to address specific technical errors that can arise from implementing this function. If you're having trouble, you can use these to help you debug. You should take some time with this function, as it's definitely the most important exercise to understand today.

In [15]:
def backprop(end_node: Tensor, end_grad: Tensor | None = None) -> None:
    """Accumulates gradients in the grad field of each leaf node.

    tensor.backward() is equivalent to backprop(tensor).

    end_node:
        The rightmost node in the computation graph. If it contains more than one element, end_grad
        must be provided.
    end_grad:
        A tensor of the same shape as end_node. Set to 1 if not specified and end_node has only one
        element.
    """
    # Get value of end_grad_arr
    end_grad_arr = np.ones_like(end_node.array) if end_grad is None else end_grad.array

    # Create dict to store gradients
    grads: dict[Tensor, Arr] = {end_node: end_grad_arr}

    # YOUR CODE HERE - iterate through the sorted computational graph, performing backprop algorithm
    processing_order: list[Tensor] = sorted_computational_graph(end_node)
    for node in processing_order:
        out_grad = grads.pop(node)

        if node.is_leaf:
            if node.requires_grad:
                node.grad = Tensor(out_grad) if node.grad is None else node.grad + out_grad
        else :
            for pos, parent in node.recipe.parents.items():
                back_fn = BACK_FUNCS.get_back_func(node.recipe.func, pos)
                grad = back_fn(out_grad, node.array, *node.recipe.args, **node.recipe.kwargs)
                grads[parent] = grad if (parent not in grads) else grads[parent] + grad


tests.test_backprop(Tensor)
tests.test_backprop_branching(Tensor)
tests.test_backprop_requires_grad_sum(Tensor)
tests.test_backprop_requires_grad_false(Tensor)
tests.test_backprop_float_arg(Tensor)

All tests in `test_backprop` passed!
All tests in `test_backprop_branching` passed!
All tests in `test_backprop_requires_grad_sum` passed!
All tests in `test_backprop_requires_grad_false` passed!
All tests in `test_backprop_float_arg` passed!


<details>
<summary>Help - I get AttributeError: 'NoneType' object has no attribute 'func'</summary>

This error is probably because you're trying to access `recipe.func` from the wrong node. Possibly, you're calling your backward functions using the parents nodes' `recipe.func`, rather than the node's `recipe.func`.
        
To explain further, suppose your computational graph is simply:

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/refs/heads/main/img/Screenshot%25202023-02-17%2520174308.png" width=320>

When you reach `b` in your backprop iteration, you should calculate the gradient wrt `a` (the only parent of `b`) and store it in your `grads` dictionary, as `grads[a]`. In order to do this, you need the backward function for `func1`, which is stored in the node `b` (recall that the recipe of a tensor can be thought of as a set of instructions for how that tensor was created).
</details>

<details>
<summary>Help - I get AttributeError: 'numpy.ndarray' object has no attribute 'array'</summary>

This might be because you've set `node.grad` to be an array, rather than a tensor. You should store gradients as tensors (think of PyTorch, where `tensor.grad` will have type `torch.Tensor`).
        
It's fine to store numpy arrays in the `grads` dictionary, but when it comes time to set a tensor's grad attribute, you should use a tensor.
</details>

<details>
<summary>Help - I get 'RuntimeError: bool value of Tensor with more than one value is ambiguous'.</summary>

This error is probably because your computational graph function checks whether a tensor is in a list. The way these classes are compared for equality is a bit funky, and using sets rather than lists should make this error go away (i.e. checking whether a tensor is in a set should be fine).
</details>

<details>
<summary>Help - I'm failing on the <code>test_backprop_requires_grad_sum</code> test and I don't know why.</summary>

This test is designed to spot cases where you're accidentally **overwriting the gradient with each backward fn call**, rather than summing them. Remember that if a node has multiple paths from itself to the end node, then that node's grad attribute should be the sum of the gradients from all those

</details>

<details>
<summary>Help - I'm stuck, and I need a template for the function.</summary>

You just need to fill in the code below the comments labelled (1) and (2).

```python
def backprop(end_node: Tensor, end_grad: Tensor | None = None) -> None:

    # Get value of end_grad_arr
    end_grad_arr = np.ones_like(end_node.array) if end_grad is None else end_grad.array

    # Create dict to store gradients
    grads: dict[Tensor, Arr] = {end_node: end_grad_arr}

    for node in sorted_computational_graph(end_node):
        outgrad = grads.pop(node)

        # (1) If this is a leaf node, then set/update the gradient if requires_grad
        ...

        # (2) If this isn't a leaf node, then iterate through this node's parents and update their values in the `grads`
        # dict, using the outgrad values returned from this node's backward function
        ...
```

</details>


<details><summary>Solution</summary>

```python
def backprop(end_node: Tensor, end_grad: Tensor | None = None) -> None:
    """Accumulates gradients in the grad field of each leaf node.

    tensor.backward() is equivalent to backprop(tensor).

    end_node:
        The rightmost node in the computation graph. If it contains more than one element, end_grad
        must be provided.
    end_grad:
        A tensor of the same shape as end_node. Set to 1 if not specified and end_node has only one
        element.
    """
    # Get value of end_grad_arr
    end_grad_arr = np.ones_like(end_node.array) if end_grad is None else end_grad.array

    # Create dict to store gradients
    grads: dict[Tensor, Arr] = {end_node: end_grad_arr}

    for node in sorted_computational_graph(end_node):
        # Get the outgrad from the grads dict
        outgrad = grads.pop(node)

        # (1) If it's a leaf node, then set/update gradient if requires_grad=True, and stop here.
        if node.is_leaf:
            if node.requires_grad:
                node.grad = Tensor(outgrad) if node.grad is None else node.grad + outgrad

        # (2) If not a leaf node then it must have a recipe, so we iterate through its parents and
        # update their grads.
        else:
            for argnum, parent in node.recipe.parents.items():
                # Get backward function, from the fwd function that created `node` from `parent`.
                back_fn = BACK_FUNCS.get_back_func(node.recipe.func, argnum)

                # Use it to compute the gradient we'll add onto parent from the path `parent -> node
                # -> ... -> end_node`.
                in_grad = back_fn(outgrad, node.array, *node.recipe.args, **node.recipe.kwargs)

                # Add this gradient to the grads dict (handling special case where parent is not in
                # grads yet).
                grads[parent] = in_grad if (parent not in grads) else grads[parent] + in_grad
```
</details>

# 3️⃣ Training on MNIST from scratch

> ##### Learning Objectives
>
> * Implement more forward and backward functions, including for indexing, summing, and matrix multiplication
> * Learn how to build higher-level abstractions like parameters and modules on top of individual functions and tensors
> * Complete the process of building up a neural network from scratch and training it via gradient descent.

Congrats on implementing backprop! Soon we'll be able to train a full model from scratch, but first we'll go through a bunch of backward functions which will be necessary for training (as well as ones that cover some interesting cases). These should be a lot like your `log_back` and `multiply_back0`, `multiplyback1` examples earlier.

## More backward functions!

> Note - some of these exercises can get a bit repetitive, and so **you're welcome to skip through many of these exercises if you don't find them interesting, and/or you're pressed for time.** The exercises in the section "Parameters & Modules" and beyond are much more conceptually valuable.
> 
> Additionally, most of these functions can be implemented simply in 1 or 2 lines, so if you find yourself writing a lot more than that then you might want to look at the solution instead.

### Exercise - `negative`

> ```yaml
> Difficulty: 🔴⚪⚪⚪⚪
> Importance: 🔵⚪⚪⚪⚪
> 
> You should spend up to ~5 minutes on this exercise (it's not a trick question, it is as simple as it looks!).
> ```

`torch.negative` just performs `-x` elementwise. Make your own version `negative` using `wrap_forward_fn`. Note, you don't need to worry about unbroadcasting here because `np.negative` won't change the input shape (technically it can since `np.negative(x, out)` is actually the negative version of `x` broadcasted to the shape of `out`, but we won't be using it in this way during our exercises).

In [16]:
def negative_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = -x elementwise."""
    return grad_out * -1


negative = wrap_forward_fn(np.negative)
BACK_FUNCS.add_back_func(np.negative, 0, negative_back)

tests.test_negative_back(Tensor)

All tests in `test_negative_back` passed!


<details><summary>Solution</summary>

```python
def negative_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = -x elementwise."""
    return -grad_out
```
</details>

### Exercise - `exp`

> ```yaml
> Difficulty: 🔴⚪⚪⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 5-10 minutes on this exercise.
> ```

Make your own version of `torch.exp`. The backward function should express the result in terms of the `out` parameter - this more efficient than expressing it in terms of `x`.

In [17]:
def exp_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = exp(x) elementwise."""
    return grad_out * out


exp = wrap_forward_fn(np.exp)
BACK_FUNCS.add_back_func(np.exp, 0, exp_back)

tests.test_exp_back(Tensor)

All tests in `test_exp_back` passed!


<details><summary>Solution</summary>

```python
def exp_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = exp(x) elementwise."""
    return out * grad_out
```
</details>

### Exercise - `reshape`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵⚪⚪⚪⚪
> 
> You should spend up to 5-10 minutes on this exercise.
> ```

`reshape` is a bit different than the functions we've dealt with so far: it changes the shape of the tensor, not its values. In other words, the backward function needs to be able to map from the gradient $\partial L / \partial \mathbf{x_r}$ to $\partial L / \partial \mathbf{x}$, where $\mathbf{x_r}$ is the reshaped version of $\mathbf{x}$.

Depending how you wrote `wrap_forward_fn` and `backprop`, you might need to go back and adjust them to handle this - if you're failing tests but think your implementation is correct, we recommend you go back to these functions and check them.

This function should just be a single line.

In [18]:
def reshape_back(grad_out: Arr, out: Arr, x: Arr, new_shape: tuple) -> Arr:
    """Backward function for torch.reshape."""
    return np.reshape(grad_out, x.shape)


reshape = wrap_forward_fn(np.reshape)
BACK_FUNCS.add_back_func(np.reshape, 0, reshape_back)

tests.test_reshape_back(Tensor)

All tests in `test_reshape_back` passed!


<details>
<summary>Solution (and explanation)</summary>

Explanation: the reshape operation that takes us from the tensor $\frac{\partial L}{\partial \mathbf{x_r}}$ to $\frac{\partial L}{\partial \mathbf{x}}$ is exactly the inverse of the forward reshape operation that produced $\mathbf{x_r}$ from $\mathbf{x}$. In other words, we want to to take `grad_out` and reshape it back to the shape of `x`.

```python
def reshape_back(grad_out: Arr, out: Arr, x: Arr, new_shape: tuple) -> Arr:
    """Backward function for torch.reshape."""
    return np.reshape(grad_out, x.shape)
```

</details>

### Exercise - `permute`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵⚪⚪⚪⚪
> 
> You should spend up to 5-10 minutes on this exercise.
> ```

In NumPy, the equivalent of `torch.permute` is called `np.transpose`, so we will wrap that. Permute is somewhat similar to reshape, but the difference is that it does actually change the order of elements in the underlying array.

Hint - just like with `reshape`, the inverse of a transposition is another transposition. You might find the function `np.argsort` useful for getting the inverse transposition.

This function should also just be a single line.

In [19]:
def permute_back(grad_out: Arr, out: Arr, x: Arr, axes: tuple) -> Arr:
    """
    Backward function for torch.permute. Works by inverting the transposition in the forward
    function.
    """
    return np.transpose(grad_out, np.argsort(axes))


BACK_FUNCS.add_back_func(np.transpose, 0, permute_back)
permute = wrap_forward_fn(np.transpose)

tests.test_permute_back(Tensor)

All tests in `test_permute_back` passed!


<details>
<summary>Solution (and explanation)</summary>

The inverse of transposing with `axes` is transposing using `np.argsort(axes)`. To see this: the forward transpose will send axis `j` to axis `axes[j]`, and so we want the inverse transposition `axes_inv` to satisfy `axes_inv[axes[j]] = j`, in other words **the indices of `axes_inv` should sort `axes`** - this is exactly what `np.argsort` does.

```python
def permute_back(grad_out: Arr, out: Arr, x: Arr, axes: tuple) -> Arr:
    """
    Backward function for torch.permute. Works by inverting the transposition in the forward
    function.
    """
    return np.transpose(grad_out, np.argsort(axes))
```

</details>

### Exercise - `sum`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 20-30 minutes on this exercise.
> ```

The output can also be smaller than the input, such as when calling `torch.sum`.

Recall that when we looked at broadcasting, the backwards operation was summing over the broadcasted dimensions. This is because a broadcast operation effectively copies our tensor, giving it more gradient paths that we need to sum over. Similarly, the backwards operation for summing is broadcasting. We can intuitively see this as follows: if we have some value `L = L(x_summed)` where `x_summed` is the result of summing `x` over some number of dimensions, then editing `x_summed[i, j, ...] += delta` has the same downstream effect as editing any one of the `x` values which were summed over to get `x_summed[i, j, ...]`. So to get the gradient of `L` wrt `x`, we need to copy (broadcast) the gradient of `L` wrt `x_summed` up to the full size of `x`.

Implementing `sum_back` should have 2 steps:

1. **Adding new dims if they were summed over with `keepdim=False`**. You can do this with `np.expand_dims`, for example if `arr` has shape `(2, 3)` then `np.expand_dims(arr, (0, 2))` has shape `(1, 2, 1, 3)`, i.e. it's a new tensor with dummy dimensions created at indices 0 and 2.
2. **Broadcasting along dims that were summed over**. Since after step (1) you've effectively reduced to the `keepdim=True` case, you can now use `np.broadcast_to` to get the correct shape.

Note, if you get weird errors that you can't explain, and these exceptions don't even go away when you use the solutions provided, this could mean that your implementation of `wrap_forward_fn` was wrong in a way which wasn't picked up by the tests. You should return to this function and try to fix it (or just use the solution).

In [20]:
def sum_back(grad_out: Arr, out: Arr, x: Arr, dim=None, keepdim=False):
    """Backward function for torch.sum"""
    
    # print(f"\nx.shape: {x.shape} \t x:{x} \ndim: {dim} \nkeepdim: {keepdim}")
    # print(f"out.shape: {out.shape} \t  out: {out}")
    
    if (not keepdim) and (dim is not None):
        grad_out = np.expand_dims(grad_out, dim)
    
    # print(f"out.shape: {out.shape} \t  out: {out}")
    
    grad_out = np.broadcast_to(grad_out, x.shape)
    # print(f"out.shape: {out.shape} \t  out: {out}")
    
    # print("grad_out:", grad_out)
    # print("ret:", grad_out * out)
    return grad_out 


def _sum(x: Arr, dim=None, keepdim=False) -> Arr:
    """Like torch.sum, calling np.sum internally."""
    return np.sum(x, axis=dim, keepdims=keepdim)


sum = wrap_forward_fn(_sum)
BACK_FUNCS.add_back_func(_sum, 0, sum_back)

tests.test_sum_keepdim_false(Tensor)
tests.test_sum_keepdim_true(Tensor)
tests.test_sum_dim_none(Tensor)
tests.test_sum_nonscalar_grad_out(Tensor)

All tests in `test_sum_keepdim_false` passed!
All tests in `test_sum_keepdim_true` passed!
All tests in `test_sum_dim_none` passed!
All tests in `test_sum_nonscalar_grad_out` passed!


<details>
<summary>Help - I'm not sure how to handle the case where <code>dim=None</code>.</summary>

You can actually handle this pretty easily - if `dim=None` then `grad_out` will be a scalar, so it's always fine to broadcast it along the dims that were summed over! This means you can skip step (1), i.e. this step only needs to handle the case where `keepdim=False` and `dim` is not `None`.

</details>

<details>
<summary>Help - I get the error "Encountered error when running `backward` in the test for nonscalar grad_out."</summary>

This error is likely due to the fact that you're expanding your tensor in a way that doesn't refer to the dimensions being summed over (i.e. the `dim` argument).

Remember that in the previous exercise we assumed that the tensors were broadcastable with each other, and our functions could just internally call `np.broadcast_to` as a result. But here, one tensor is the sum over another tensor's dimensions, and if `keepdim=False` then they might not broadcast. For instance, if `x.shape = (2, 5)`, `out = x.sum(dim=1)` has shape `(2,)` and `grad_out.shape = (2,)`, then the tensors `grad_out` and `x` are not broadcastable.

How can you carefully handle the case where `keepdim=False` and `dim` doesn't just refer to dimensions at the start of the tensor? (Hint - try and use `np.expand_dims`).

</details>


<details><summary>Solution</summary>

```python
def sum_back(grad_out: Arr, out: Arr, x: Arr, dim=None, keepdim=False):
    """Backward function for torch.sum"""
    # Step (1): if keepdim=False, then we need to add back in dims, so grad_out and x have the
    # same number of dims. We don't bother with the dim=None case, since then grad_out is a scalar
    # and this will be handled by our broadcasting in step (2).
    if (not keepdim) and (dim is not None):
        grad_out = np.expand_dims(grad_out, dim)

    # Step (2): repeat grad_out along the dims over which x was summed
    return np.broadcast_to(grad_out, x.shape)
```
</details>

### Elementwise add, subtract, divide

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 10-15 minutes on this exercise.
> ```

These are exactly analogous to the multiply case. Note that Python and NumPy have the notion of "floor division", which is a truncating integer division as in `7 // 3 = 2`. You can ignore floor division: - we only need the usual floating point division which is called "true division".

Use lambda functions to define and register the backward functions each in one line. We've given you the first one.

In [21]:
add = wrap_forward_fn(np.add)
subtract = wrap_forward_fn(np.subtract)
true_divide = wrap_forward_fn(np.true_divide)

BACK_FUNCS.add_back_func(np.add, 0, lambda grad_out, out, x, y: unbroadcast(grad_out, x))
# YOUR CODE HERE - continue adding to BACK_FUNCS, for each of the 3 functions & both arg orders
BACK_FUNCS.add_back_func(np.add, 1, lambda grad_out, out, x, y: unbroadcast(grad_out, y))
BACK_FUNCS.add_back_func(np.subtract, 0, lambda grad_out, out, x, y: unbroadcast(grad_out, x))
BACK_FUNCS.add_back_func(np.subtract, 1, lambda grad_out, out, x, y: unbroadcast(-grad_out, y))
BACK_FUNCS.add_back_func(np.true_divide, 0, lambda grad_out, out, x, y: unbroadcast(grad_out / y, x))
BACK_FUNCS.add_back_func(np.true_divide, 1, lambda grad_out, out, x, y: unbroadcast(grad_out * -x/(y**2), y))

tests.test_add_broadcasted(Tensor)
tests.test_subtract_broadcasted(Tensor)
tests.test_truedivide_broadcasted(Tensor)

All tests in `test_add_broadcasted` passed!
All tests in `test_subtract_broadcasted` passed!
All tests in `test_truedivide_broadcasted` passed!


<details><summary>Solution</summary>

```python
BACK_FUNCS.add_back_func(np.add, 0, lambda grad_out, out, x, y: unbroadcast(grad_out, x))
BACK_FUNCS.add_back_func(np.add, 1, lambda grad_out, out, x, y: unbroadcast(grad_out, y))
BACK_FUNCS.add_back_func(np.subtract, 0, lambda grad_out, out, x, y: unbroadcast(grad_out, x))
BACK_FUNCS.add_back_func(np.subtract, 1, lambda grad_out, out, x, y: unbroadcast(-grad_out, y))
BACK_FUNCS.add_back_func(
    np.true_divide, 0, lambda grad_out, out, x, y: unbroadcast(grad_out / y, x)
)
BACK_FUNCS.add_back_func(
    np.true_divide, 1, lambda grad_out, out, x, y: unbroadcast(grad_out * (-x / y**2), y)
)
```
</details>

### Indexing

If we have the gradient of `L` wrt `x[index]`, what is the gradient of `L` wrt `x`? The answer is it'll be an array of zeros, filled in with the values of `dL/dx[index]` at the appropriate index positions. For example, if `x = [1, 2, 3]` and `L = x[0]`, then we trivially have `dL/dx[0] = 1`, and we can compute `dL/dx = [1, 0, 0]` in this way.

In its full generality, exactly how you can index a `torch.Tensor` is really complicated and there are quite a few cases to handle separately. Our implementation only handles 2 cases:

- The index is an integer or tuple of integers.
- The index is a tuple of (array or Tensor) representing coordinates. Each array is 1D and of equal length. Some coordinates may be repeated. This is [Integer array indexing](https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing).

This latter case is very important, because it describes how we index correct logprobs / probabilities. For example if we're training a classifier and we have tensors `logprobs.shape = (batch_size, n_classes)` and `targets.shape = (n_classes,)` then we index the correct logprobs using `logprobs[arange(batch_size), targets]` (note that the `arange` function has been given to you earlier, when we defined the `Tensor` class - it's a simple wrapper around `np.arange`).

In [22]:
Index = int | tuple[int, ...] | tuple[Arr] | tuple[Tensor]


def coerce_index(index: Index):
    """Helper function: converts array of tensors to array of numpy arrays."""
    if isinstance(index, tuple) and all(isinstance(i, Tensor) for i in index):
        return tuple([i.array for i in index])
    else:
        return index


def _getitem(x: Arr, index: Index) -> Arr:
    """Like x[index] when x is a torch.Tensor."""
    return x[coerce_index(index)]


def getitem_back(grad_out: Arr, out: Arr, x: Arr, index: Index):
    """
    Backwards function for _getitem.

    Hint: use np.add.at(a, indices, b)
    This function works just like a[indices] += b, except that it allows for repeated indices.
    """
    new_grad_out = np.full_like(x, 0)
    np.add.at(new_grad_out, coerce_index(index), grad_out)
    return new_grad_out


getitem = wrap_forward_fn(_getitem)
BACK_FUNCS.add_back_func(_getitem, 0, getitem_back)

### Non-Differentiable Functions

For functions like `torch.argmax` or `torch.eq`, there's no sensible way to define gradients with respect to the input tensor. For these, we will still use `wrap_forward_fn` because we still need to unbox the arguments and box the result, but by passing `is_differentiable=False` we can avoid doing any unnecessary computation.

We've given you this one as an example:

In [23]:
def _argmax(x: Arr, dim=None, keepdim=False):
    """Like torch.argmax."""
    result = np.argmax(x, axis=dim)
    if keepdim:
        return np.expand_dims(result, axis=([] if dim is None else dim))
    return result


argmax = wrap_forward_fn(_argmax, is_differentiable=False)

a = Tensor([1.0, 0.0, 3.0, 4.0], requires_grad=True)
b = a.argmax()
assert not b.requires_grad
assert b.recipe is None
assert b.item() == 3

### In-Place Operations

Supporting in-place operations introduces substantial complexity and generally doesn't help performance that much. The problem is that if any of the inputs used in the backward function have been modified in-place since the forward pass, then the backward function will incorrectly calculate using the modified version. PyTorch will warn you when this causes a problem with the error "RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.".

Note - you don't have to fill anything in here; just run the cell. If you're curious, you can implement inplace operations as a bonus exercise at the end, but for now we just warn against inplace operations unless we specify otherwise.

In [24]:
def add_(x: Tensor, other: Tensor, alpha: float = 1.0) -> Tensor:
    """Like torch.add_. Compute x += other * alpha in-place and return tensor."""
    np.add(x.array, other.array * alpha, out=x.array)
    return x


def sub_(x: Tensor, other: Tensor, alpha: float = 1.0) -> Tensor:
    """Like torch.sub_. Compute x -= other * alpha in-place and return tensor."""
    np.subtract(x.array, other.array * alpha, out=x.array)
    return x


def safe_example():
    """This example should work properly."""
    a = Tensor([0.0, 1.0, 2.0, 3.0], requires_grad=True)
    b = Tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
    a.add_(b)
    c = a * b
    c.sum().backward()
    assert a.grad is not None and np.allclose(a.grad.array, [2.0, 3.0, 4.0, 5.0])
    assert b.grad is not None and np.allclose(b.grad.array, [2.0, 4.0, 6.0, 8.0])


def unsafe_example():
    """
    This example is expected to compute the wrong gradients, because dc/db is calculated using the
    modified a.
    """
    a = Tensor([0.0, 1.0, 2.0, 3.0], requires_grad=True)
    b = Tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
    c = a * b
    a.add_(b)
    c.sum().backward()
    if a.grad is not None and np.allclose(a.grad.array, [2.0, 3.0, 4.0, 5.0]):
        print("Grad wrt a is OK!")
    else:
        print("Grad wrt a is WRONG!")
    if b.grad is not None and np.allclose(b.grad.array, [0.0, 1.0, 2.0, 3.0]):
        print("Grad wrt b is OK!")
    else:
        print("Grad wrt b is WRONG!")


safe_example()
unsafe_example()

Grad wrt a is OK!
Grad wrt b is WRONG!


### Mixed Scalar-Tensor Operations

You may have been wondering why our `Tensor` class has to define both `__mul__` and `__rmul__` magic methods.

Without `__rmul__` defined, executing `2 * a` when `a` is a `Tensor` would try to call `2.__mul__(a)`, and the built-in class `int` would be confused about how to handle this.

Since we have defined `__rmul__` for you at the start, and you implemented multiply to work with floats as arguments, the following should "just work".

In [25]:
a = Tensor([0, 1, 2, 3], requires_grad=True)
(a * 2).sum().backward()
b = Tensor([0, 1, 2, 3], requires_grad=True)
(2 * b).sum().backward()
assert a.grad is not None
assert b.grad is not None
assert np.allclose(a.grad.array, b.grad.array)

### Exercise - `max`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵⚪⚪⚪
> 
> You should spend up to 10-15 minutes on this exercise.
> ```

Since this is an elementwise function, we can think about the scalar case. For scalar $x$, $y$, the derivative for $\max(x, y)$ wrt $x$ is 1 when $x > y$ and 0 when $x < y$. What should happen when $x = y$?

Intuitively, since $\max(x, x)$ is equivalent to the identity function which has a derivative of 1 wrt $x$, it makes sense for the sum of our partial derivatives wrt $x$ and $y$ to also therefore total 1. The convention used by PyTorch is to split the derivative evenly between the two arguments. We will follow this behavior for compatibility, but it's just as legitimate to say it's 1 wrt $x$ and 0 wrt $y$, or some other arbitrary combination that sums to one.

<details>
<summary>Help - I'm not sure how to implement this function.</summary>

Try returning `grad_out * bool_sum`, where `bool_sum` is an array constructed from the sum of two boolean arrays.

You can alternatively use `np.where`.
</details>

<details>
<summary>Help - I'm passing the first test but not the second.</summary>

This probably means that you haven't implemented `unbroadcast`. You'll need to do this, to get `grad_out` into the right shape before you use it in `np.where`.
</details>

In [26]:
def maximum_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt x."""
    bool_sum = np.zeros(out.shape)
    bool_sum += 0.5 * (x == y)
    bool_sum += (x > y)
    return unbroadcast(grad_out * bool_sum, x)


def maximum_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt y."""
    bool_sum = np.zeros(out.shape)
    bool_sum += 0.5 * (x == y)
    bool_sum += (x < y)
    return unbroadcast(grad_out * bool_sum, y)


maximum = wrap_forward_fn(np.maximum)
BACK_FUNCS.add_back_func(np.maximum, 0, maximum_back0)
BACK_FUNCS.add_back_func(np.maximum, 1, maximum_back1)

tests.test_maximum(Tensor)
tests.test_maximum_broadcasted(Tensor)

All tests in `test_maximum` passed!
All tests in `test_maximum_broadcasted` passed!


<details><summary>Solution</summary>

```python
def maximum_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt x."""
    bool_sum = (x > y) + 0.5 * (x == y)
    return unbroadcast(grad_out * bool_sum, x)


def maximum_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt y."""
    bool_sum = (x < y) + 0.5 * (x == y)
    return unbroadcast(grad_out * bool_sum, y)
```
</details>

### Exercise - functional `ReLU`

> ```yaml
> Difficulty: 🔴⚪⚪⚪⚪
> Importance: 🔵⚪⚪⚪⚪
> 
> You should spend ~5 minutes on this exercise.
> ```

A simple and correct ReLU function can be defined in terms of your maximum function. Note the PyTorch version also supports in-place operation, which we are punting to the bonus section for now.

Again, at $x = 0$ your derivative could reasonably be anything between 0 and 1 inclusive, but we've followed PyTorch in making it 0.5. This means you can just use the `maximum` function defined above!

In [27]:
def relu(x: Tensor) -> Tensor:
    """Like torch.nn.function.relu(x, inplace=False)."""
    return maximum(x, 0.0)


tests.test_relu(Tensor)

All tests in `test_relu` passed!


<details><summary>Solution</summary>

```python
def relu(x: Tensor) -> Tensor:
    """Like torch.nn.function.relu(x, inplace=False)."""
    return maximum(x, 0.0)
```
</details>

### Exercise - 2D `matmul`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 20-25 minutes on this exercise.
> ```

Implement your version of `torch.matmul`, restricting it to the simpler case where both inputs are 2D (this means we don't need to worry about unbroadcasting or anything).

Note - althought the solution to this exercise is very short (just one line), you may find the actual mathematical derivation a bit tricky. We've given hints to help you, which we recommend using if you're stuck.

In [28]:
def _matmul2d(x: Arr, y: Arr) -> Arr:
    """Matrix multiply restricted to the case where both inputs are exactly 2D."""
    return x @ y


def matmul2d_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    return grad_out @ y.T

def matmul2d_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    return x.T @ grad_out


matmul = wrap_forward_fn(_matmul2d)
BACK_FUNCS.add_back_func(_matmul2d, 0, matmul2d_back0)
BACK_FUNCS.add_back_func(_matmul2d, 1, matmul2d_back1)

tests.test_matmul2d(Tensor)

All tests in `test_matmul2d` passed!


<details>
<summary>Help - I need a hint about the math</summary>

Let $X$, $Y$ and $M$ denote the variables `x`, `y` and `out`, so we have the matrix relation $M = XY$. The object `grad_out` is a tensor with elements `grad_out[p, q]` $ = \frac{\partial L}{\partial M_{p q}}$.

The output of `matmul2d_back0` should be the gradient of $L$ wrt $X$, i.e. it should have elements $\frac{\partial L}{\partial X_{i j}}$. Can you write this in terms of the elements of `x`, `y`, `out` and `grad_out`?
</details>

<details>
<summary>Help - I need the math explained</summary>

We can write $\frac{\partial L}{\partial X_{i j}}$ as:

$$
\begin{aligned}
\frac{\partial L}{\partial X_{i j}} &=\sum_{pq} \frac{\partial L}{\partial M_{p q}} \frac{\partial M_{p q}}{\partial X_{i j}} \\
&=\sum_{pqr} \left[\text{ grad\_out }\right]_{p q} \frac{\partial (X_{p r} Y_{r q})}{\partial X_{i j}} \\
&=\sum_{q} \left[\text{ grad\_out }\right]_{iq} Y_{j q} \\
&= \left[\text{ grad\_out } \times Y^{\top}\right]_{ij}
\end{aligned}
$$

where the second line follows because $M_{pq} = \sum_r X_{pr} Y_{rq}$ (and we can rearrange the summands), and the third line follows because $\frac{\partial{X_{pr}}}{X_{ij}} = 1 \text{ if } (p, r) = (i, j), \text{ else } 0$.


In other words, the `x.grad` attribute should be is `grad_out @ y.T`.

You can calculate the gradient wrt `y` in a similar way - we leave this as an exercise for the reader.
</details>


<details><summary>Solution</summary>

```python
def _matmul2d(x: Arr, y: Arr) -> Arr:
    """Matrix multiply restricted to the case where both inputs are exactly 2D."""
    return x @ y


def matmul2d_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    return grad_out @ y.T


def matmul2d_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    return x.T @ grad_out
```
</details>

## Parameters & Modules

We've now written enough backwards passes that we can go up a layer and write our own `nn.Parameter` and `nn.Module`. These are important abstractions that help us building up neural networks.

Below is a simple implementation of `Parameter`. It is itself a `Tensor`, shares storage with the provided `Tensor` and requires_grad is `True` by default - that's it! Make sure you understand the code being run in this cell to test the functionality of this class.

In [29]:
class Parameter(Tensor):
    def __init__(self, tensor: Tensor, requires_grad=True):
        """Share the array with the provided tensor."""
        return super().__init__(tensor.array, requires_grad=requires_grad)

    def __repr__(self):
        return f"Parameter containing:\n{super().__repr__()}"


x = Tensor([1.0, 2.0, 3.0])
p = Parameter(x)
assert p.requires_grad
assert p.array is x.array
assert (
    repr(p)
    == "Parameter containing:\nTensor(array([1., 2., 3.], dtype=float32), requires_grad=True)"
)
x.add_(Tensor(np.array(2.0)))
assert np.allclose(p.array, np.array([3.0, 4.0, 5.0])), (
    "in-place modifications to the original tensor should affect the parameter"
)

Just like `torch.Tensor`, the `nn.Module` class has a lot of functionality which we mostly don't care about today. We will just implement enough to get our network training.

Below is a simple implementation. We'll explain it a bit below (if you're already experienced in Python then this might be obvious to you and you can skip it).

- **Single-underscore attributes** are a notational convention; they're not treated differently by Python but they're used to indicate that the attribute is private and shouldn't be accessed directly by anyone using the class.
    - `_modules` is a dict mapping module names to module objects. The `modules` method returns an iterator over these modules.
    - `_parameters` is similar, with added recursion to include submodule parameters.
- **Double-underscore attributes** are special methods that determine how your class instance behaves when you use certain syntax.
    - `__call__` determines what the module does when you call it like a function. In this case, `module(*args, **kwargs)` calls `module.forward(**args, **kwargs)` - this is why we only ever need to implement `forward` in the modules we've written so far.
    - `__setattr__` manages attribute setting (i.e. running `module.attr = value` actually calls `module.__setattr__("attr", value)`). The default behaviour is to add the attribute to `self.__dict__`, but we've added custom logic so modules & parameters are also added to `self._modules` and `self._parameters` respectively - this is basically how logic like `module.parameters()` can work.
        - Note that there's a related method `__getattr__` which specifies attribute getting behaviour when lookup in `self.__dict__` fails.

In [30]:
class Module:
    _modules: dict[str, "Module"]
    _parameters: dict[str, Parameter]

    def __init__(self):
        self._modules: dict[str, "Module"] = {}
        self._parameters: dict[str, Parameter] = {}

    def modules(self) -> Iterator["Module"]:
        """Return the direct child modules of this module, not including self."""
        yield from self._modules.values()

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """
        Return an iterator over Module parameters.

        recurse: if True, the iterator includes parameters of submodules, recursively.
        """
        yield from self._parameters.values()
        if recurse:
            for mod in self.modules():
                yield from mod.parameters(recurse=True)

    def __setattr__(self, key: str, val: Any) -> None:
        """
        If val is a Parameter or Module, store it in the appropriate _parameters or _modules dict.
        Otherwise, call __setattr__ from the superclass.
        """
        if isinstance(val, Parameter):
            self._parameters[key] = val
        elif isinstance(val, Module):
            self._modules[key] = val
        super().__setattr__(key, val)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self):
        raise NotImplementedError("Subclasses must implement forward!")

    def __repr__(self):
        _indent = lambda s_, nSpaces: re.sub("\n", "\n" + (" " * nSpaces), s_)
        lines = [f"({key}): {_indent(repr(module), 2)}" for key, module in self._modules.items()]
        return "".join(
            [
                self.__class__.__name__ + "(",
                "\n  " + "\n  ".join(lines) + "\n" if lines else "",
                ")",
            ]
        )


class TestInnerModule(Module):
    def __init__(self):
        super().__init__()
        self.param1 = Parameter(Tensor([1.0]))
        self.param2 = Parameter(Tensor([2.0]))


class TestModule(Module):
    def __init__(self):
        super().__init__()
        self.inner = TestInnerModule()
        self.param3 = Parameter(Tensor([3.0]))


mod = TestModule()
assert list(mod.modules()) == [mod.inner]
assert list(mod.parameters()) == [mod.param3, mod.inner.param1, mod.inner.param2]
print("Manually verify that the repr looks reasonable:")
print(mod)
print("All tests for `Module` passed!")

Manually verify that the repr looks reasonable:
TestModule(
  (inner): TestInnerModule()
)
All tests for `Module` passed!


### Exercise - implement `Linear`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵🔵⚪
> 
> You should spend up to 20-25 minutes on this exercise.
> ```

Now, let's go one level of abstraction higher and create a `Linear` module. This should inherit from `Module` and have `__init__` & `forward` methods just like your linear module inheriting from `nn.Module` in previous exercises. In fact, your code can probably be extremely similar to the time you implemented `Linear` in the earlier exercises, except you'll need to use methods we've defined already. You should be able to do everything you need in `forward` using just the matmul operator `@`, the transpose operator `.T` (which is equivalent to `.permute(-1, -2)` as you can see in the tensor class above) and standard tensor addition `+`.

To restate the task in case you don't remember it from the previous exercises, you should:

- Define `self.weight` and `self.bias` in `__init__`, with both tensors having a uniform distribution in the range `[-sf, sf]` where `sf = 1/sqrt(in_features)`,
- Write the appropriate affine operation in `forward`, i.e. multiplying by `self.weight` and adding `self.bias` if it exists.

Don't forget to wrap your weights as `Parameter(Tensor(...))`.

In [31]:
class Linear(Module):
    weight: Parameter
    bias: Parameter | None

    def __init__(self, in_features: int, out_features: int, bias=True):
        """
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # w = np.random.random((out_features, in_features)) / np.sqrt(in_features)
        # self.weight = Parameter(Tensor(w))

        # if bias:
        #     b = np.random.random(out_features) / np.sqrt(in_features)
        #     self.bias = Parameter(Tensor(b))
        # else:
        #     self.bias = None

        sf = in_features**-0.5
        self.weight = Parameter(Tensor(sf * (2 * np.random.rand(out_features, in_features) - 1)))
        self.bias = Parameter(Tensor(sf * (2 * np.random.rand(out_features) - 1))) if bias else None

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        out = x @ self.weight.T
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )


linear = Linear(3, 4)
assert isinstance(linear.weight, Tensor)
assert linear.weight.requires_grad

input = Tensor([[1.0, 2.0, 3.0]])
output = linear(input)
assert output.requires_grad

expected_output = input @ linear.weight.T + linear.bias
np.testing.assert_allclose(output.array, expected_output.array)

print("All tests for `Linear` passed!")

All tests for `Linear` passed!


<details><summary>Solution</summary>

```python
class Linear(Module):
    weight: Parameter
    bias: Parameter | None

    def __init__(self, in_features: int, out_features: int, bias=True):
        """
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        sf = in_features**-0.5
        self.weight = Parameter(Tensor(sf * (2 * np.random.rand(out_features, in_features) - 1)))
        self.bias = Parameter(Tensor(sf * (2 * np.random.rand(out_features) - 1))) if bias else None

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        out = (
            x @ self.weight.T
        )  # transpose has been defined as .permute(-1, -2), see the `Tensor` class
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )
```
</details>

Finally, for the sake of completeness, we'll define a `ReLU` module:

In [32]:
class ReLU(Module):
    def forward(self, x: Tensor) -> Tensor:
        return relu(x)

Now we can define a MLP suitable for classifying MNIST, with zero PyTorch dependency!

In [33]:
class MLP(Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(28 * 28, 64)
        self.linear2 = Linear(64, 64)
        self.relu1 = ReLU()
        self.relu2 = ReLU()
        self.output = Linear(64, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = x.reshape((x.shape[0], 28 * 28))
        x = self.relu1(self.linear1(x))
        x = self.relu2(self.linear2(x))
        x = self.output(x)
        return x

### Exercise - implement `cross_entropy`

> ```yaml
> Difficulty: 🔴🔴🟠⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 10-15 minutes on this exercise.
> ```

Make use of your integer array indexing to implement `cross_entropy`. See the documentation page [here](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).

We discussed this briefly in the section on indexing earlier, but the kind of indexing you should be doing on your logprobs is `logprobs[range(batch_size), true_labels]`, since this is equivalent to returning the vector of length `batch_size` with elements `[logprobs[0, true_labels[0]], logprobs[1, true_labels[1]], ...]`. Rather than using `range`, you should be using the `arange` function we've provided for you (this is equivalent to torch's `torch.arange` function, and is defined just below the `Tensor` class).

Note - if you're using the `exp` function, it's usually good to make your implementation numerically stable (since taking the exponential of large numbers is prone to overflow). The common solution here is to subtract the maximum value of the tensor from all elements. However, you don't need to worry about that here (consider it a bonus exercise).

In [34]:
def cross_entropy(logits: Tensor, true_labels: Tensor) -> Tensor:
    """Like torch.nn.functional.cross_entropy with reduction='none'.

    logits: shape (batch, classes)
    true_labels: shape (batch,). Each element is the index of the correct label in the logits.

    Return: shape (batch, ) containing the per-example loss.
    """

    logprobs = logits - logits.exp().sum(dim=-1, keepdim=True).log()
    return -logprobs[arange(0, logits.shape[0]), true_labels]


# tests.test_cross_entropy(Tensor, cross_entropy)
logits = Tensor(
    [
        [0, -100, -100, -100],  # equivalent to certainty of class=0
        [1 / 4, 1 / 4, 1 / 4, 1 / 4],  # uniform over all classes
        [0, 0, 0, -100],  # equivalent to uniform over first 3 classes
        [1000, 0, 0, 0],  # unstable test case
    ]
)
true_labels = Tensor([0, 0, 0, 0])
expected = Tensor([0.0, np.log(4), np.log(3), 0])
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore", RuntimeWarning)

# First test: numerically stable
# print("Testing for numerically stable cases ... ", end="")
actual = cross_entropy(logits[:3], true_labels[:3])
np.testing.assert_allclose(actual.array, expected[:3].array)
# print("passed!")

# # Second test: unstable (will generate nans if not handled correctly)
# print("Testing for numerically unstable cases ... ", end="")
# actual = cross_entropy(logits[3:], true_labels[3:])
# np.testing.assert_allclose(actual.array, expected[3:].array)
# print("passed!")

print("All tests in `test_cross_entropy` passed!")

Tensor(array([[ 0.0e+00, -1.0e+02, -1.0e+02, -1.0e+02],
       [ 2.5e-01,  2.5e-01,  2.5e-01,  2.5e-01],
       [ 0.0e+00,  0.0e+00,  0.0e+00, -1.0e+02],
       [ 1.0e+03,  0.0e+00,  0.0e+00,  0.0e+00]], dtype=float32), requires_grad=False) slice(None, 3, None)
Tensor(array([0, 0, 0, 0]), requires_grad=False) slice(None, 3, None)
Tensor(array([[   0.       , -100.       , -100.       , -100.       ],
       [  -1.3862944,   -1.3862944,   -1.3862944,   -1.3862944],
       [  -1.0986123,   -1.0986123,   -1.0986123, -101.09861  ]],
      dtype=float32), requires_grad=False) (Tensor(array([0, 1, 2]), requires_grad=False), Tensor(array([0, 0, 0]), requires_grad=False))
Tensor(array([0.       , 1.3862944, 1.0986123, 0.       ], dtype=float32), requires_grad=False) slice(None, 3, None)
All tests in `test_cross_entropy` passed!


<details>
<summary>Help - I'm not sure how to get logprobs from logits.</summary>

They're equal up to a constant: `logprobs = logits - log(sum(exp(logits)))` (where the sum is over the last dimension).

To see why this is true: let's define `C = logits - logprobs`. We know `sum(exp(logits - C)) = sum(exp(logprobs)) = 1` (this is by definition of `logprobs`). Factoring out the `exp(-C)` term, we get `exp(C) = sum(exp(logits))`, hence `C = log(sum(exp(logits)))` as required.

</details>

<details>
<summary>Solution</summary>

```python
def cross_entropy(logits: Tensor, true_labels: Tensor) -> Tensor:
    """Like torch.nn.functional.cross_entropy with reduction='none'.

    logits: shape (batch, classes)
    true_labels: shape (batch,). Each element is the index of the correct label in the logits.

    Return: shape (batch, ) containing the per-example loss.
    """
    batch_size = logits.shape[0]
    logprobs = logits - logits.exp().sum(-1, keepdim=True).log()
    return -logprobs[arange(0, batch_size), true_labels]
```

or alternatively we can solve a slightly different way, which is still equivalent:

```python
true = logits[arange(0, batch_size), true_labels]
return -log(exp(true) / exp(logits).sum(1))
```

</details>

## `NoGrad` context manager

The last thing our backpropagation system needs is the ability to turn it off completely like `torch.no_grad` (or `torch.inference_mode`). We've given you an implementation below, which works by modifying the global `grad_tracking_enabled` variable. 

A few notes on the actual python here (again for people who are more familiar with Python and understand it can skip this):

- The `global` keyword is required in order to modify the global `grad_tracking_enabled` variable. We can still reference its value without this keyword, but we wouldn't be able to change it.
- The special `__enter__` and `__exit__` methods are part of the protocol for context managers, which is a more pythonic way of doing this kind of thing. If we have a context manager block like `with NoGrad(): ...`, then we'll run `NoGrad().__enter__()` before any of the code in this block, and `NoGrad().__exit__()` after the block finishes.

In [35]:
class NoGrad:
    """Context manager that disables grad inside the block. Like torch.no_grad."""

    was_enabled: bool

    def __enter__(self):
        """
        Method which is called whenever the context manager is entered, i.e. at the start of the
        `with NoGrad():` block. This disables gradient tracking (but stores the value it had before,
        so we can set it back to this on exit).
        """
        global grad_tracking_enabled
        self.was_enabled = grad_tracking_enabled
        grad_tracking_enabled = False

    def __exit__(self, type, value, traceback):
        """
        Method which is called whenever we exit the context manager. This sets the global
        `grad_tracking_enabled` variable back to the value it had before we entered the context
        manager.
        """
        global grad_tracking_enabled
        grad_tracking_enabled = self.was_enabled


assert grad_tracking_enabled
with NoGrad():
    assert not grad_tracking_enabled
assert grad_tracking_enabled
print(
    "Verified that we've disabled gradients inside `NoGrad`, then set back to its previous "
    "value once we exit."
)

Verified that we've disabled gradients inside `NoGrad`, then set back to its previous value once we exit.


### Exercise - implement `SGD`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
> 
> You should spend up to 10-20 minutes on this exercise.
> ```

In today's final exercise, you should implement the `SGD` class methods `zero_grad` and `step`. This should be pretty familiar if you've gone through yesterday's exercises on optimizers (although without all the bells and whistles from those exercises, because we're literally just implementing plain SGD with no momentum, weight decay or anything).

Important note - in yesterday's exercises it was important to use inplace operations, so we would actually modify the existing tensor data rather than creating new tensors, and this is also the case here. The inplace operation `+=` is supported, since under the hood this calls `__iadd__` which we've defined in our `Tensor` class (same for subtraction, the underlying method here is `__isub__`). Note that we did discuss earlier how inplace operations are very risky for backprop, this is generally true however here we're using it for parameter updates which aren't meant to be differentiated and which are performed just before zeroing all gradients - this makes it safe in this particular context.

In [36]:
class SGD:
    def __init__(self, params: Iterable[Parameter], lr: float):
        """Vanilla SGD with no additional features."""
        self.params = list(params)
        self.lr = lr
        self.b = [None for _ in self.params]

    def zero_grad(self) -> None:
        """Iterates through params, and sets all grads to None."""
        for p in self.params:
            p.grad = None

    def step(self) -> None:
        """Iterates through params, and updates each of them by subtracting `param.grad * lr`."""
        with NoGrad():
            for p in self.params:
                assert p.grad is not None, f"p.requires_grad={p.requires_grad}"
                p -= p.grad * self.lr


tests.test_sgd(Parameter, Tensor, SGD)

All tests for `SGD` passed!


<details><summary>Solution</summary>

```python
class SGD:
    def __init__(self, params: Iterable[Parameter], lr: float):
        """Vanilla SGD with no additional features."""
        self.params = list(params)
        self.lr = lr
        self.b = [None for _ in self.params]

    def zero_grad(self) -> None:
        """Iterates through params, and sets all grads to None."""
        for p in self.params:
            p.grad = None

    def step(self) -> None:
        """Iterates through params, and updates each of them by subtracting `param.grad * lr`."""
        with NoGrad():
            for p in self.params:
                p -= p.grad * self.lr
```
</details>

## Training Your Network

We've already looked at data loading and training loops earlier in the course, so we'll provide a minimal version of these today as well as the data loading code.

In [37]:
train_loader, test_loader = get_mnist()
visualize(train_loader)

Preprocessing data...


Training data: 100%|██████████| 60000/60000 [00:01<00:00, 43113.02it/s]
Test data: 100%|██████████| 10000/10000 [00:00<00:00, 45341.18it/s]


To finish the day, below is some code for a training/testing loop for MNIST images, which also logs & plots the results.

Note, it's normal to encounter some bugs and glitches at this point - just go back and fix them until everything runs! Because backprop is annoying and fiddly and depends heavily on exactly how the implementation works (with too many edge cases to test all of them), you may have to resort to replacing your code with the reference solution until you find the source of the error - this is a bit frustrating, but we'd be lying if we said ML isn't without its share of slow debugging sessions!

In [38]:
def train(
    model: MLP,
    train_loader: DataLoader,
    optimizer: SGD,
    epoch: int,
    train_loss_list: list | None = None,
):
    print(f"Epoch: {epoch}")
    progress_bar = tqdm(train_loader)
    for data, target in progress_bar:
        data, target = Tensor(data.numpy()), Tensor(target.numpy())
        optimizer.zero_grad()
        output = model(data)
        loss = cross_entropy(output, target).sum() / len(output)
        loss.backward()
        progress_bar.set_description(f"Train set: Avg loss: {loss.item():.3f}")
        optimizer.step()
        if train_loss_list is not None:
            train_loss_list.append(loss.item())


def test(model: MLP, test_loader: DataLoader, test_accuracy_list: list | None = None):
    test_loss = 0
    test_accuracy = 0
    with NoGrad():
        for data, target in test_loader:
            data, target = Tensor(data.numpy()), Tensor(target.numpy())
            output: Tensor = model(data)
            test_loss += cross_entropy(output, target).sum().item()
            pred = output.argmax(dim=1, keepdim=True)
            test_accuracy += (pred == target.reshape(pred.shape)).sum().item()
    n_data = len(test_loader.dataset)
    test_loss /= n_data
    print(
        f"Test set:  Avg loss: {test_loss:.3f}, Accuracy: {test_accuracy}/{n_data} "
        f"({test_accuracy / n_data:.1%})"
    )
    if test_accuracy_list is not None:
        test_accuracy_list.append(test_accuracy / n_data)


num_epochs = 5
model = MLP()
start = time.time()
train_loss_list = []
test_accuracy_list = []
optimizer = SGD(model.parameters(), 0.01)
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, epoch, train_loss_list)
    test(model, test_loader, test_accuracy_list)
    
print(f"\nCompleted in {time.time() - start: .2f}s")

Epoch: 0


Train set: Avg loss: 2.300:   3%|▎         | 4/118 [00:00<00:03, 37.57it/s]

Tensor(array([[-2.299943 , -2.3495545, -2.338394 , ..., -2.1801422, -2.2315998,
        -2.3077364],
       [-2.275294 , -2.289484 , -2.3036883, ..., -2.2109418, -2.253685 ,
        -2.3163168],
       [-2.3089776, -2.338474 , -2.2239766, ..., -2.1487591, -2.293903 ,
        -2.3215642],
       ...,
       [-2.3637326, -2.3944924, -2.3504696, ..., -2.205048 , -2.2136858,
        -2.28388  ],
       [-2.3443809, -2.3639371, -2.2241094, ..., -2.2054224, -2.267664 ,
        -2.2045853],
       [-2.3670921, -2.3956125, -2.3692565, ..., -2.1962612, -2.277049 ,
        -2.4178824]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.296:   8%|▊         | 9/118 [00:00<00:02, 43.73it/s]

Tensor(array([[-2.3050165, -2.2568858, -2.2957857, ..., -2.2480857, -2.2457933,
        -2.3244245],
       [-2.37841  , -2.3282928, -2.2713623, ..., -2.2068787, -2.301724 ,
        -2.238103 ],
       [-2.3597612, -2.3843136, -2.315166 , ..., -2.1600673, -2.2905123,
        -2.339905 ],
       ...,
       [-2.3556995, -2.3637097, -2.360222 , ..., -2.2052279, -2.2100444,
        -2.2717252],
       [-2.2805045, -2.3686   , -2.2787292, ..., -2.1966271, -2.2514064,
        -2.2554471],
       [-2.2920728, -2.3066509, -2.2873821, ..., -2.2492104, -2.270681 ,
        -2.2761354]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.281:  12%|█▏        | 14/118 [00:00<00:02, 36.26it/s]

Tensor(array([[-2.312201 , -2.304058 , -2.2667263, ..., -2.2271614, -2.334898 ,
        -2.274638 ],
       [-2.3892224, -2.3215191, -2.2535841, ..., -2.2363985, -2.316127 ,
        -2.3442738],
       [-2.4059718, -2.4215932, -2.3620915, ..., -2.1892846, -2.372299 ,
        -2.3229866],
       ...,
       [-2.3428597, -2.2911632, -2.3596756, ..., -2.1926436, -2.1958978,
        -2.3589716],
       [-2.3262503, -2.2726514, -2.3071506, ..., -2.291366 , -2.2884653,
        -2.3210502],
       [-2.3710465, -2.3198164, -2.3171716, ..., -2.2250135, -2.3256407,
        -2.2423136]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.278:  12%|█▏        | 14/118 [00:00<00:02, 36.26it/s]

Tensor(array([[-2.3653545, -2.3160634, -2.3039765, ..., -2.2115786, -2.2908967,
        -2.2477896],
       [-2.3359458, -2.2893662, -2.2184289, ..., -2.1775799, -2.405689 ,
        -2.333919 ],
       [-2.3580797, -2.3174157, -2.2479715, ..., -2.1759412, -2.2741647,
        -2.2149055],
       ...,
       [-2.3877478, -2.322762 , -2.2976708, ..., -2.193311 , -2.1977975,
        -2.2545745],
       [-2.3433015, -2.3279133, -2.3195307, ..., -2.2007496, -2.245069 ,
        -2.2749546],
       [-2.3403354, -2.3533182, -2.3077786, ..., -2.2141407, -2.197729 ,
        -2.3057346]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.267:  22%|██▏       | 26/118 [00:00<00:02, 45.54it/s]

Tensor(array([[-2.3624363, -2.337148 , -2.3492095, ..., -2.19095  , -2.2743735,
        -2.2911239],
       [-2.320666 , -2.3347018, -2.2803457, ..., -2.1678464, -2.269194 ,
        -2.223154 ],
       [-2.360737 , -2.263852 , -2.3217049, ..., -2.219971 , -2.1883774,
        -2.2060106],
       ...,
       [-2.3391747, -2.2922525, -2.2349305, ..., -2.260922 , -2.2706132,
        -2.1971817],
       [-2.357501 , -2.235905 , -2.3394244, ..., -2.2211173, -2.2963557,
        -2.4252827],
       [-2.3387196, -2.3320503, -2.3565998, ..., -2.1914911, -2.302931 ,
        -2.3841517]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.233:  34%|███▍      | 40/118 [00:00<00:01, 58.13it/s]

Tensor(array([[-2.3232   , -2.2917595, -2.355635 , ..., -2.1522424, -2.189696 ,
        -2.32972  ],
       [-2.3967004, -2.2548673, -2.294546 , ..., -2.197011 , -2.1850667,
        -2.255207 ],
       [-2.3821747, -2.3431237, -2.376081 , ..., -2.2200873, -2.2520838,
        -2.3599324],
       ...,
       [-2.3322802, -2.2892628, -2.294621 , ..., -2.2580705, -2.3315785,
        -2.3669195],
       [-2.3750634, -2.3101423, -2.3457382, ..., -2.185086 , -2.2525396,
        -2.4101768],
       [-2.3902965, -2.3589392, -2.3367503, ..., -2.1768837, -2.2231355,
        -2.2328181]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.216:  45%|████▍     | 53/118 [00:01<00:01, 57.73it/s]

Tensor(array([[-2.3200636, -2.327254 , -2.2922993, ..., -2.2707362, -2.2843916,
        -2.2888906],
       [-2.315177 , -2.3324564, -2.243495 , ..., -2.1506176, -2.318692 ,
        -2.2581575],
       [-2.2653394, -2.3355467, -2.2950573, ..., -2.2105067, -2.3441432,
        -2.2578564],
       ...,
       [-2.31164  , -2.4604053, -2.2857435, ..., -2.0446503, -2.4130862,
        -2.1988053],
       [-2.376706 , -2.2577891, -2.3098893, ..., -2.2318954, -2.2473104,
        -2.3679209],
       [-2.415385 , -2.4690447, -2.3837812, ..., -2.04557  , -2.4142191,
        -2.2410767]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.161:  56%|█████▌    | 66/118 [00:01<00:00, 59.54it/s]

Tensor(array([[-2.3487165, -2.4184246, -2.2424383, ..., -2.2014024, -2.4577684,
        -2.159934 ],
       [-2.5069203, -2.4054036, -2.2271783, ..., -2.0848403, -2.4072683,
        -2.0413547],
       [-2.3257754, -2.3496685, -2.0939114, ..., -2.1910896, -2.5424678,
        -2.146802 ],
       ...,
       [-2.2856205, -2.389841 , -2.3064983, ..., -2.1912854, -2.3115702,
        -2.3051455],
       [-2.2708292, -2.3611715, -2.2772617, ..., -2.180787 , -2.312296 ,
        -2.1941762],
       [-2.2040799, -2.473943 , -2.4239838, ..., -2.2016132, -2.3565552,
        -2.3059795]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.152:  62%|██████▏   | 73/118 [00:01<00:00, 51.71it/s]

Tensor(array([[-2.4389608, -2.3466387, -2.2028022, ..., -2.1335695, -2.5168662,
        -2.276178 ],
       [-2.324962 , -2.3958812, -2.4089649, ..., -2.2339616, -2.2023838,
        -2.3083553],
       [-2.2987022, -2.59615  , -2.4223218, ..., -1.9587305, -2.509765 ,
        -2.2602801],
       ...,
       [-2.459984 , -2.37495  , -2.282394 , ..., -2.1283364, -2.3698502,
        -2.3589628],
       [-2.4283233, -2.3310084, -2.302261 , ..., -2.1465847, -2.2776644,
        -2.364072 ],
       [-2.4165802, -2.516431 , -2.2785864, ..., -2.021519 , -2.4327595,
        -2.031805 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.117:  71%|███████   | 84/118 [00:01<00:00, 44.91it/s]

Tensor(array([[-2.5631104, -2.6462843, -2.2055726, ..., -1.9397755, -2.5711799,
        -1.9302948],
       [-2.506017 , -2.0174687, -2.3438866, ..., -2.237675 , -2.2103715,
        -2.2473087],
       [-2.4334574, -2.4662619, -2.104207 , ..., -2.0949159, -2.502983 ,
        -2.1627078],
       ...,
       [-2.5301569, -2.4941611, -2.2770271, ..., -2.100385 , -2.485994 ,
        -2.1016505],
       [-2.155989 , -2.456372 , -2.3899748, ..., -2.2378848, -2.4033217,
        -2.232726 ],
       [-2.3870013, -2.5920427, -2.4853923, ..., -1.9179978, -2.422211 ,
        -2.0525012]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.090:  76%|███████▋  | 90/118 [00:01<00:00, 48.15it/s]

Tensor(array([[-2.3825583, -2.4901557, -2.3253815, ..., -2.1571147, -2.468136 ,
        -2.2371006],
       [-2.296629 , -2.3568256, -2.4715006, ..., -2.2007732, -2.313177 ,
        -2.524482 ],
       [-2.5719314, -2.491916 , -2.3045642, ..., -2.0866394, -2.3989413,
        -1.9774468],
       ...,
       [-2.5437846, -2.221269 , -2.214774 , ..., -2.0888345, -2.4334724,
        -2.1959145],
       [-2.4252744, -2.646911 , -2.5393314, ..., -1.9836957, -2.3070533,
        -2.0829113],
       [-2.5582917, -2.6117   , -2.373908 , ..., -1.8346231, -2.6050005,
        -1.9040341]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.047:  85%|████████▍ | 100/118 [00:02<00:00, 39.54it/s]

Tensor(array([[-2.6455185, -2.8418832, -2.386767 , ..., -1.935416 , -2.6719606,
        -1.7318481],
       [-2.4741683, -2.1315286, -2.3289194, ..., -2.2467616, -2.2115598,
        -2.1735592],
       [-2.2968638, -2.3293188, -2.3531578, ..., -2.1986876, -2.2174528,
        -2.2391527],
       ...,
       [-2.66033  , -2.6684518, -2.2636795, ..., -1.899343 , -2.5990741,
        -2.0753026],
       [-2.226761 , -2.4739342, -2.4672797, ..., -2.2681537, -2.1806884,
        -2.3821433],
       [-2.3878183, -2.3544455, -2.4205837, ..., -2.1083956, -2.249688 ,
        -2.0371852]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 2.006:  96%|█████████▌| 113/118 [00:02<00:00, 49.33it/s]

Tensor(array([[-2.505391 , -2.3634777, -2.1629758, ..., -2.053172 , -2.638293 ,
        -2.1441953],
       [-2.8665183, -2.5830247, -2.3367226, ..., -1.9566536, -2.5173204,
        -2.017976 ],
       [-2.4573817, -2.5588033, -2.3804567, ..., -1.9830917, -2.5687358,
        -2.106643 ],
       ...,
       [-2.5709696, -2.2812483, -2.243996 , ..., -2.184417 , -2.5529234,
        -2.2696724],
       [-2.3700325, -2.2691271, -2.4141827, ..., -2.3118858, -2.2099097,
        -2.459553 ],
       [-2.1175933, -2.6047597, -2.2959964, ..., -2.1236026, -2.585979 ,
        -2.2404768]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.968: 100%|██████████| 118/118 [00:02<00:00, 48.05it/s]


Tensor(array([[-2.9619682, -2.8141966, -2.5480766, ..., -1.8460906, -2.7379527,
        -1.5908865],
       [-2.5055249, -2.2232218, -2.4573193, ..., -2.1562068, -2.3514698,
        -2.2318408],
       [-2.463581 , -2.1458805, -2.1468656, ..., -2.315143 , -2.3865843,
        -2.4534082],
       ...,
       [-1.9867465, -2.6602464, -2.3979425, ..., -2.09348  , -2.4092095,
        -2.1521046],
       [-3.0067875, -3.0453362, -2.4939303, ..., -1.9284754, -2.738695 ,
        -1.6588728],
       [-2.3392022, -2.550208 , -2.5266504, ..., -2.2933488, -2.1192598,
        -2.5102053]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.983:   0%|          | 0/118 [00:00<?, ?it/s]

Tensor(array([[-2.2095397, -2.727636 , -2.51737  , ..., -2.0516884, -2.2321699,
        -2.0885696],
       [-3.0305188, -3.2470026, -2.5491724, ..., -1.8477874, -2.9267092,
        -1.6310962],
       [-1.717649 , -2.9208908, -2.4652205, ..., -2.0759974, -2.7893243,
        -2.1144574],
       ...,
       [-2.1880412, -2.4978428, -2.2764735, ..., -2.09779  , -2.5620577,
        -2.2626858],
       [-2.412918 , -2.287861 , -2.4271505, ..., -2.2395797, -2.2730763,
        -2.2007437],
       [-2.434734 , -3.4137676, -2.5943313, ..., -1.7412131, -2.9111986,
        -1.5983222]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.974:   3%|▎         | 4/118 [00:00<00:05, 22.00it/s]

Tensor(array([[-2.846951 , -3.1858745, -2.402026 , ..., -1.877578 , -2.9468126,
        -1.7000263],
       [-2.4790492, -3.000021 , -2.4735608, ..., -1.9291207, -2.7157514,
        -1.6023297],
       [-2.4268022, -2.1585271, -2.4425082, ..., -2.2605395, -1.9788043,
        -2.3595498],
       ...,
       [-2.8903089, -2.6646886, -2.5202632, ..., -1.904963 , -2.6102107,
        -1.6700891],
       [-2.6204545, -2.2405381, -2.2467093, ..., -2.3732548, -2.26891  ,
        -2.544071 ],
       [-2.5661914, -1.8001883, -2.3393772, ..., -2.3812225, -2.1673245,
        -2.5605073]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.942:   6%|▌         | 7/118 [00:00<00:04, 22.33it/s]

Tensor(array([[-2.6392612, -1.7881802, -2.4146802, ..., -2.3114028, -2.1967325,
        -2.302933 ],
       [-2.5857902, -2.0771408, -2.3776596, ..., -2.2552118, -2.25808  ,
        -2.2248583],
       [-2.3775918, -2.4037168, -2.5055099, ..., -2.1797736, -2.101569 ,
        -2.419545 ],
       ...,
       [-2.3729758, -2.4760683, -2.400548 , ..., -2.3083591, -2.3054929,
        -2.2512057],
       [-2.7488618, -2.4100747, -2.5135016, ..., -2.0200608, -2.3525372,
        -1.885134 ],
       [-2.3783076, -2.3953495, -2.4786339, ..., -2.3725247, -2.2020273,
        -2.6475952]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.926:   9%|▉         | 11/118 [00:00<00:04, 25.40it/s]

Tensor(array([[-2.3458736, -2.3296556, -2.4150283, ..., -2.2775269, -2.234295 ,
        -2.396422 ],
       [-1.9211634, -2.4826818, -2.4657729, ..., -2.2695663, -2.5067117,
        -2.4343042],
       [-2.697687 , -2.7448535, -2.506172 , ..., -1.6789932, -2.646335 ,
        -1.8277558],
       ...,
       [-2.6030617, -2.3025842, -2.5453374, ..., -2.0618703, -2.2735262,
        -1.9502879],
       [-2.5803356, -2.113031 , -2.3907025, ..., -2.3836257, -2.192925 ,
        -2.5418026],
       [-2.694533 , -1.6886373, -2.3755198, ..., -2.3088834, -2.1560934,
        -2.5227005]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.905:  12%|█▏        | 14/118 [00:00<00:04, 24.50it/s]

Tensor(array([[-2.9839854, -3.0454347, -2.393516 , ..., -1.7909513, -2.8953433,
        -1.7871752],
       [-2.5443625, -2.2907062, -1.7557523, ..., -2.381324 , -2.775705 ,
        -2.293873 ],
       [-1.8197722, -2.594662 , -2.355196 , ..., -2.2121637, -2.7968988,
        -2.2758703],
       ...,
       [-2.4807076, -2.408163 , -2.558599 , ..., -1.7990334, -2.3237114,
        -2.0389924],
       [-2.4043434, -2.3552985, -2.624523 , ..., -2.2496302, -2.0166156,
        -2.521582 ],
       [-2.69469  , -2.1582868, -2.4813998, ..., -2.0832014, -2.2548845,
        -1.977052 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.867:  14%|█▍        | 17/118 [00:00<00:04, 23.01it/s]

Tensor(array([[-2.605755 , -2.419906 , -2.331732 , ..., -2.135491 , -2.415816 ,
        -2.1655598],
       [-2.441124 , -2.405725 , -2.2039177, ..., -2.3535757, -2.590689 ,
        -2.4167871],
       [-2.6494522, -2.2039647, -2.265144 , ..., -1.9512635, -2.476628 ,
        -2.112609 ],
       ...,
       [-2.7783394, -1.6820922, -2.456929 , ..., -2.3037267, -2.2163856,
        -2.2817008],
       [-2.4240565, -2.831327 , -2.7930346, ..., -1.6817787, -2.3515544,
        -1.9493002],
       [-2.4996142, -2.745727 , -2.8204105, ..., -1.7652828, -2.416961 ,
        -1.9175316]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.868:  19%|█▉        | 23/118 [00:00<00:03, 31.29it/s]

Tensor(array([[-1.2655674, -3.3562045, -2.7072043, ..., -2.156733 , -2.7989607,
        -2.1776676],
       [-2.853347 , -2.4606407, -2.4219637, ..., -1.8036401, -2.5762587,
        -1.7918339],
       [-2.7261283, -1.5460012, -2.356927 , ..., -2.4753592, -2.2028694,
        -2.6352143],
       ...,
       [-1.6163831, -2.839177 , -2.3457658, ..., -2.2618773, -2.5740774,
        -2.2338152],
       [-2.2706578, -2.5786338, -2.4801955, ..., -2.379891 , -2.20061  ,
        -2.2722213],
       [-1.3201059, -3.053532 , -2.505878 , ..., -2.1176202, -2.9536874,
        -2.2050257]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.825:  23%|██▎       | 27/118 [00:01<00:03, 25.32it/s]

Tensor(array([[-2.8438923, -1.5771369, -2.4256134, ..., -2.3841019, -1.9780391,
        -2.6351209],
       [-2.3365011, -2.4076748, -1.7577055, ..., -2.4013207, -2.8820426,
        -2.5031173],
       [-2.822098 , -2.0820365, -2.1498778, ..., -2.331893 , -2.458351 ,
        -2.3790154],
       ...,
       [-2.5536063, -2.1151693, -2.5215364, ..., -2.1703463, -2.2581103,
        -2.0583067],
       [-2.5739374, -2.69183  , -2.605423 , ..., -1.9706146, -2.5007653,
        -1.5854286],
       [-2.746528 , -2.222636 , -2.191363 , ..., -2.1618202, -2.7498283,
        -2.2941868]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.822:  23%|██▎       | 27/118 [00:01<00:03, 25.32it/s]

Tensor(array([[-1.3380864, -3.135345 , -2.422444 , ..., -2.1484883, -2.9315286,
        -2.3174021],
       [-2.3751874, -2.3345318, -2.5938907, ..., -2.3239675, -2.2185018,
        -2.5581708],
       [-2.5289783, -2.0628946, -2.3490517, ..., -2.1655521, -2.5804377,
        -2.2905781],
       ...,
       [-2.8440812, -2.8365352, -2.3905168, ..., -1.871445 , -2.8105688,
        -1.6525794],
       [-2.2693112, -2.122106 , -2.260949 , ..., -2.5191903, -2.2943912,
        -2.7760353],
       [-2.7545414, -2.7079985, -2.5593104, ..., -1.970612 , -2.529345 ,
        -1.8013778]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.788:  25%|██▌       | 30/118 [00:01<00:03, 25.19it/s]

Tensor(array([[-2.2846444, -1.9408586, -2.0964432, ..., -2.4760795, -2.5659134,
        -2.8819776],
       [-2.67678  , -2.6571465, -2.4633136, ..., -1.9881283, -2.7271106,
        -1.7334867],
       [-2.8776941, -1.4506419, -2.5346313, ..., -2.3767664, -2.0759373,
        -2.469423 ],
       ...,
       [-2.6785846, -1.5770962, -2.2050393, ..., -2.4001403, -2.2581587,
        -2.490985 ],
       [-2.788381 , -1.4465418, -2.3771777, ..., -2.519127 , -2.171795 ,
        -2.7182295],
       [-2.824955 , -1.4604213, -2.3648758, ..., -2.4746401, -2.1077924,
        -2.6968296]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.792:  28%|██▊       | 33/118 [00:01<00:03, 25.81it/s]

Tensor(array([[-2.205668 , -2.5749235, -2.4806428, ..., -2.1793704, -2.2670286,
        -2.3949661],
       [-2.7043047, -1.9418877, -2.4794512, ..., -2.2120268, -2.1660688,
        -1.9957908],
       [-2.548071 , -2.2737124, -1.9094156, ..., -2.5246482, -2.517439 ,
        -2.4083827],
       ...,
       [-1.4943817, -3.0010858, -2.6071022, ..., -2.1172442, -2.8249745,
        -2.1297967],
       [-2.4594657, -1.8812494, -1.9228351, ..., -2.4476464, -2.4704022,
        -2.5307617],
       [-2.5698016, -2.2339306, -2.419082 , ..., -2.4582975, -1.9230499,
        -2.492185 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.766:  31%|███       | 36/118 [00:01<00:04, 19.46it/s]

Tensor(array([[-2.6875892, -1.97313  , -2.5322738, ..., -1.9799876, -2.1307046,
        -2.2492478],
       [-2.2341256, -2.7080033, -2.3099866, ..., -2.0761976, -2.380994 ,
        -2.1768408],
       [-2.421687 , -1.9726559, -2.4078524, ..., -2.392645 , -2.1317348,
        -2.2327852],
       ...,
       [-2.7397106, -1.5196304, -2.2620203, ..., -2.5164695, -2.1428623,
        -2.7855098],
       [-2.870886 , -1.5439754, -2.3666809, ..., -2.3771718, -2.3247826,
        -2.3918855],
       [-2.1265874, -2.4990017, -2.2800498, ..., -2.5137367, -2.2038672,
        -2.6869874]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.722:  35%|███▍      | 41/118 [00:01<00:03, 25.11it/s]

Tensor(array([[-1.7287178, -2.4552457, -2.2750432, ..., -2.425429 , -2.5687094,
        -2.8619883],
       [-2.722017 , -3.2176003, -2.6691306, ..., -1.3308035, -3.0575562,
        -1.6970077],
       [-2.3303094, -2.3284016, -2.2688184, ..., -2.303647 , -2.096136 ,
        -2.4576797],
       ...,
       [-2.912418 , -2.4432511, -2.5244036, ..., -1.8651373, -2.507252 ,
        -1.773168 ],
       [-2.5086277, -2.4523466, -2.7290883, ..., -1.903792 , -1.972062 ,
        -2.159288 ],
       [-1.350281 , -2.9340692, -2.5831034, ..., -2.3254304, -2.6109936,
        -2.4429162]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.722:  37%|███▋      | 44/118 [00:01<00:03, 24.59it/s]

Tensor(array([[-1.8925991, -2.775748 , -2.1013653, ..., -2.0357807, -2.588771 ,
        -2.4831188],
       [-3.172965 , -3.0793931, -2.6244562, ..., -1.8475403, -2.811211 ,
        -1.4962224],
       [-2.5600986, -2.629491 , -2.6626225, ..., -1.9154309, -2.1792822,
        -2.0666418],
       ...,
       [-2.1118174, -2.4881134, -2.5086563, ..., -2.491686 , -2.2997646,
        -2.9223447],
       [-3.1104174, -3.892354 , -2.8547018, ..., -1.6979852, -3.4203024,
        -1.5516052],
       [-2.767143 , -2.3355901, -2.293322 , ..., -2.2184703, -2.7193496,
        -2.224301 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.682:  40%|███▉      | 47/118 [00:01<00:02, 25.65it/s]

Tensor(array([[-2.9583983, -1.3619566, -2.5312545, ..., -2.4691699, -2.0959702,
        -2.5353367],
       [-3.230341 , -2.8694782, -2.577442 , ..., -1.9531708, -2.6435251,
        -1.6550229],
       [-2.0555832, -3.025815 , -2.152255 , ..., -2.0125592, -3.3124511,
        -2.2048302],
       ...,
       [-2.749366 , -2.5302951, -2.1162899, ..., -2.3585389, -2.9439156,
        -2.545025 ],
       [-2.9680564, -3.5422997, -2.4133022, ..., -1.7011962, -3.5315049,
        -1.4888289],
       [-1.8950319, -3.0585318, -2.8020911, ..., -2.0745716, -2.3638427,
        -2.3755515]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.666:  40%|███▉      | 47/118 [00:02<00:02, 25.65it/s]

Tensor(array([[-2.9188905, -1.3691628, -2.2084687, ..., -2.5015934, -2.205812 ,
        -2.579547 ],
       [-2.9231143, -1.9731282, -2.1893382, ..., -2.20967  , -2.1473665,
        -2.2317386],
       [-1.2837251, -3.1236613, -2.472136 , ..., -2.3679817, -2.9918015,
        -2.5448287],
       ...,
       [-2.7844884, -2.5744643, -2.7331777, ..., -1.6531823, -2.6601307,
        -1.6973572],
       [-2.9966025, -3.4099586, -2.5777752, ..., -1.9032674, -3.115768 ,
        -1.6846337],
       [-2.473844 , -1.8187141, -2.164731 , ..., -2.563321 , -2.2022998,
        -2.7807288]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.611:  47%|████▋     | 55/118 [00:02<00:02, 31.03it/s]

Tensor(array([[-3.037291 , -1.333916 , -2.480025 , ..., -2.3514507, -2.1845152,
        -2.453133 ],
       [-2.4744296, -3.0390906, -2.698268 , ..., -1.3862457, -2.843886 ,
        -1.8474   ],
       [-1.952326 , -2.7790396, -2.4516592, ..., -2.0127723, -2.7221491,
        -2.535397 ],
       ...,
       [-1.3183699, -3.1513638, -2.696148 , ..., -2.3218   , -2.51798  ,
        -2.8787615],
       [-2.3002236, -2.232431 , -2.6475236, ..., -2.3193948, -1.8781127,
        -2.4635766],
       [-3.4745393, -2.554042 , -2.2068615, ..., -2.2089128, -2.7251885,
        -1.7662069]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.603:  47%|████▋     | 55/118 [00:02<00:02, 31.03it/s]

Tensor(array([[-3.2202668 , -1.0882194 , -2.3223288 , ..., -2.687296  ,
        -2.2213135 , -2.952311  ],
       [-2.196684  , -2.824517  , -2.619928  , ..., -2.3917773 ,
        -2.1169274 , -2.0166862 ],
       [-0.70620584, -3.894187  , -2.913788  , ..., -2.447081  ,
        -3.2760968 , -2.545462  ],
       ...,
       [-3.2650013 , -1.2125266 , -2.4350557 , ..., -2.4248104 ,
        -2.207712  , -2.5783458 ],
       [-2.1505375 , -2.1525064 , -2.0103621 , ..., -2.2500196 ,
        -2.4309773 , -2.6779237 ],
       [-3.1804743 , -1.7153896 , -2.0172114 , ..., -2.9541473 ,
        -2.162487  , -2.9370422 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.566:  56%|█████▌    | 66/118 [00:02<00:01, 35.69it/s]

Tensor(array([[-3.2277179, -3.8892303, -3.201854 , ..., -1.2110449, -3.3339472,
        -1.4328389],
       [-2.7781165, -2.6665928, -2.2909656, ..., -1.9460202, -2.3332806,
        -1.9855703],
       [-2.7519948, -3.9417474, -3.1741   , ..., -1.2944905, -3.3132155,
        -1.4555862],
       ...,
       [-3.5018363, -4.3949103, -3.0815966, ..., -1.7683737, -3.6226106,
        -1.479594 ],
       [-1.8895315, -3.9408846, -2.9469662, ..., -1.3744533, -3.3494906,
        -1.7056782],
       [-2.1631746, -2.8314838, -2.0268998, ..., -2.7832744, -2.727976 ,
        -2.9518795]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.541:  56%|█████▌    | 66/118 [00:02<00:01, 35.69it/s]

Tensor(array([[-2.761709 , -2.5065958, -1.5766177, ..., -2.5268998, -2.7133007,
        -2.0464664],
       [-3.1585858, -1.224575 , -2.6361375, ..., -2.4665432, -2.0699773,
        -2.6225305],
       [-2.6219199, -2.5142589, -1.5294012, ..., -2.7081022, -2.6972272,
        -2.4859478],
       ...,
       [-3.3366451, -3.5098712, -2.0577862, ..., -2.0029202, -3.401659 ,
        -1.8084466],
       [-3.30834  , -2.0106957, -2.710371 , ..., -1.9675075, -1.9096954,
        -2.1150436],
       [-2.0863411, -2.2715733, -2.100021 , ..., -2.698976 , -2.6313198,
        -3.2887607]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.467:  64%|██████▍   | 76/118 [00:02<00:01, 39.51it/s]

Tensor(array([[-1.1064439, -3.159486 , -2.7151194, ..., -2.4211845, -2.7602015,
        -2.7913709],
       [-2.348711 , -2.664597 , -2.1412773, ..., -2.4464905, -2.7628062,
        -2.5098035],
       [-3.1712728, -3.9950218, -2.963363 , ..., -1.5573565, -3.299128 ,
        -1.2428479],
       ...,
       [-2.9893913, -1.989069 , -2.137719 , ..., -2.4765208, -2.6801775,
        -2.5095508],
       [-2.5764277, -2.0221703, -2.326551 , ..., -2.2794998, -2.3738408,
        -2.600718 ],
       [-3.1868217, -4.0810347, -3.3840342, ..., -1.0347549, -3.4060109,
        -1.5211682]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.430:  64%|██████▍   | 76/118 [00:02<00:01, 39.51it/s]

Tensor(array([[-3.4466896 , -2.4421637 , -2.8974843 , ..., -2.185617  ,
        -2.048658  , -1.7350446 ],
       [-2.2179196 , -3.0317183 , -2.7651553 , ..., -2.600887  ,
        -1.9568802 , -2.9038951 ],
       [-1.9699438 , -2.5694091 , -1.625355  , ..., -2.6761289 ,
        -2.8769073 , -2.5006373 ],
       ...,
       [-3.4626021 , -0.9179158 , -2.5978086 , ..., -3.0554113 ,
        -1.9829267 , -3.4399354 ],
       [-2.7508872 , -2.737756  , -2.196546  , ..., -2.2087893 ,
        -1.9216845 , -1.8809097 ],
       [-0.92249227, -3.2516618 , -2.9177947 , ..., -2.5638728 ,
        -2.860279  , -2.7831414 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.462:  69%|██████▊   | 81/118 [00:02<00:01, 33.48it/s]

Tensor(array([[-2.2514286, -2.8316422, -2.60338  , ..., -3.0929255, -2.3166053,
        -3.530907 ],
       [-3.5685356, -0.826367 , -2.5783484, ..., -3.0234468, -2.2093651,
        -3.3719225],
       [-3.6602669, -0.8481281, -2.5011566, ..., -3.0196626, -2.1499202,
        -3.2902536],
       ...,
       [-2.3709826, -2.9511056, -2.416037 , ..., -2.8874063, -1.8478804,
        -2.6770592],
       [-3.8264527, -3.9640565, -2.9050405, ..., -1.8740711, -3.1885443,
        -1.5484965],
       [-3.20887  , -1.376287 , -2.5288937, ..., -2.26542  , -2.2386708,
        -2.4436724]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.380:  72%|███████▏  | 85/118 [00:02<00:01, 32.93it/s]

Tensor(array([[-3.2192771 , -2.8719752 , -2.796998  , ..., -1.2721623 ,
        -2.8816612 , -1.764666  ],
       [-2.8221383 , -1.8303118 , -1.652677  , ..., -2.7892776 ,
        -2.2149324 , -2.9576128 ],
       [-2.8383105 , -1.6518636 , -2.4526503 , ..., -2.2037938 ,
        -1.7899795 , -2.5707445 ],
       ...,
       [-2.909973  , -1.8900383 , -2.6326482 , ..., -1.877775  ,
        -1.9843898 , -2.2024992 ],
       [-3.3463717 , -1.5889598 , -2.4861586 , ..., -2.227804  ,
        -1.9172394 , -2.3561697 ],
       [-3.687973  , -0.83643055, -2.585539  , ..., -3.07759   ,
        -2.141789  , -3.5474193 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.367:  77%|███████▋  | 91/118 [00:03<00:00, 38.69it/s]

Tensor(array([[-1.633833 , -3.8261158, -2.09151  , ..., -2.6436057, -3.3415172,
        -2.7866864],
       [-3.0840607, -3.5193322, -3.2430046, ..., -1.0484061, -2.7556822,
        -1.5316094],
       [-3.9102292, -4.6524115, -2.899459 , ..., -1.6591326, -3.787484 ,
        -1.5357393],
       ...,
       [-4.031305 , -3.7288327, -2.17019  , ..., -1.9698807, -3.6964753,
        -1.5169948],
       [-1.8947123, -3.8447597, -2.9125016, ..., -2.9124506, -2.1676939,
        -3.3404288],
       [-3.069604 , -2.0503376, -2.4776275, ..., -1.9484692, -1.8604934,
        -2.23523  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.274:  86%|████████▌ | 101/118 [00:03<00:00, 37.66it/s]

Tensor(array([[-2.6617353, -3.4074981, -2.3931792, ..., -1.9372728, -1.9666407,
        -1.7394428],
       [-2.3763218, -2.7041893, -2.2740965, ..., -2.6288784, -2.237485 ,
        -2.5458562],
       [-1.9619154, -3.1928792, -2.606971 , ..., -3.6640146, -2.5246634,
        -3.906502 ],
       ...,
       [-1.4729136, -2.7984054, -2.0635467, ..., -2.6029737, -2.653088 ,
        -2.973794 ],
       [-2.4698467, -3.903442 , -2.771892 , ..., -2.153146 , -3.0501513,
        -1.8062141],
       [-3.383808 , -2.4933343, -2.8864481, ..., -1.1143856, -2.5634863,
        -1.8408047]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.203:  91%|█████████ | 107/118 [00:03<00:00, 42.02it/s]

Tensor(array([[-4.4660063, -0.5401466, -2.7691085, ..., -3.5305734, -2.2483642,
        -3.896325 ],
       [-2.2616313, -2.44675  , -2.4222512, ..., -2.875173 , -2.232685 ,
        -3.4306343],
       [-1.7508705, -4.100959 , -2.4240828, ..., -2.171553 , -2.4850333,
        -2.3552377],
       ...,
       [-3.5055041, -3.6265213, -2.5353296, ..., -1.8819392, -2.8866265,
        -1.1888362],
       [-3.7684398, -1.6618536, -2.5892503, ..., -2.318569 , -1.7598172,
        -2.2182941],
       [-1.0669526, -3.8993683, -2.1641128, ..., -2.4799743, -2.949008 ,
        -2.861118 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.183: 100%|██████████| 118/118 [00:03<00:00, 31.36it/s]


Tensor(array([[-4.050749  , -1.1966264 , -2.2247925 , ..., -2.5028374 ,
        -2.0640733 , -2.367999  ],
       [-0.40706277, -4.877533  , -3.262052  , ..., -3.4157186 ,
        -3.1534019 , -3.6632977 ],
       [-3.6876314 , -4.3430166 , -3.5794265 , ..., -1.5944617 ,
        -3.1801238 , -1.1113206 ],
       ...,
       [-4.4418583 , -5.307934  , -3.771732  , ..., -1.2252886 ,
        -3.884058  , -1.3821658 ],
       [-4.667625  , -0.5222106 , -2.882075  , ..., -3.4922056 ,
        -2.222714  , -3.7195864 ],
       [-4.2648735 , -0.6258731 , -2.2702057 , ..., -3.2798812 ,
        -2.3306346 , -3.4733653 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.179:   2%|▏         | 2/118 [00:00<00:05, 19.58it/s]

Tensor(array([[-4.5259504 , -0.49837828, -2.6733599 , ..., -3.4851327 ,
        -2.478452  , -3.8672395 ],
       [-2.5120845 , -2.091979  , -2.1525736 , ..., -3.560065  ,
        -2.5236948 , -4.083996  ],
       [-2.5141041 , -4.206344  , -1.8412683 , ..., -3.7930746 ,
        -4.549981  , -3.706917  ],
       ...,
       [-4.2627964 , -0.6230693 , -2.815706  , ..., -3.113957  ,
        -2.2110653 , -3.444707  ],
       [-4.27658   , -5.3466587 , -3.451135  , ..., -1.8081127 ,
        -4.143791  , -1.3297397 ],
       [-4.0935583 , -4.8987947 , -3.2363758 , ..., -1.7221569 ,
        -3.4746041 , -1.3481027 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.198:   2%|▏         | 2/118 [00:00<00:05, 19.58it/s]

Tensor(array([[-3.6401656 , -4.3731346 , -2.8563538 , ..., -1.6287699 ,
        -3.1650026 , -1.2917532 ],
       [-4.0451303 , -4.536697  , -3.2881916 , ..., -2.0986574 ,
        -2.8372822 , -1.159309  ],
       [-4.6320505 , -0.45512915, -2.7027066 , ..., -3.5730603 ,
        -2.6214385 , -3.9559543 ],
       ...,
       [-4.9814973 , -6.429324  , -4.6323166 , ..., -1.8127764 ,
        -3.851003  , -1.2539024 ],
       [-4.039157  , -4.902532  , -3.9399474 , ..., -0.75057435,
        -3.7255688 , -1.6288531 ],
       [-4.4137764 , -0.6029539 , -2.7183497 , ..., -3.06079   ,
        -2.376963  , -3.4055772 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.116:  10%|█         | 12/118 [00:00<00:03, 35.15it/s]

Tensor(array([[-3.1444972, -1.8740814, -2.4920135, ..., -2.6298685, -1.3253392,
        -2.69186  ],
       [-3.4519296, -1.0454216, -2.5319748, ..., -3.6367812, -1.9065442,
        -3.9460106],
       [-2.9243104, -3.3630116, -3.2024314, ..., -4.135602 , -2.194161 ,
        -4.61431  ],
       ...,
       [-4.1288943, -2.7026124, -3.348121 , ..., -1.6082814, -1.6301332,
        -1.7792162],
       [-4.075033 , -1.5942199, -1.7985957, ..., -3.104816 , -2.4774518,
        -3.078072 ],
       [-3.0191681, -3.2533941, -2.9949598, ..., -1.706708 , -1.9414141,
        -1.9463608]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.090:  10%|█         | 12/118 [00:00<00:03, 35.15it/s]

Tensor(array([[-0.5180154 , -5.3121476 , -3.6874318 , ..., -3.3677676 ,
        -2.7045686 , -3.6959255 ],
       [-5.0202503 , -6.40512   , -4.2307925 , ..., -2.0203133 ,
        -4.60082   , -1.4398549 ],
       [-3.677897  , -1.7721488 , -2.2109237 , ..., -3.4805014 ,
        -1.1383109 , -3.2142913 ],
       ...,
       [-0.65218306, -6.1998844 , -4.015107  , ..., -1.5981491 ,
        -4.039813  , -2.3675752 ],
       [-4.0188894 , -2.665565  , -2.9915512 , ..., -2.1073515 ,
        -1.5480679 , -2.1425946 ],
       [-3.043926  , -1.6083    , -2.5822418 , ..., -2.9398608 ,
        -1.5321943 , -3.362295  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.053:  14%|█▎        | 16/118 [00:00<00:03, 32.41it/s]

Tensor(array([[-3.8986726 , -2.9658785 , -2.150009  , ..., -3.2102988 ,
        -1.4460092 , -2.4208424 ],
       [-3.8983548 , -4.893415  , -3.389329  , ..., -1.6479778 ,
        -3.6182423 , -1.500782  ],
       [-5.212821  , -5.949032  , -3.323753  , ..., -2.4135625 ,
        -4.272942  , -1.5430987 ],
       ...,
       [-0.13880491, -6.8122263 , -4.2376986 , ..., -3.86359   ,
        -4.7092433 , -4.1039085 ],
       [-2.2896852 , -4.1195354 , -1.4752419 , ..., -3.0471444 ,
        -3.5113323 , -3.041176  ],
       [-4.645364  , -5.4968214 , -4.799904  , ..., -0.50749683,
        -3.9265733 , -1.5981688 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.060:  19%|█▊        | 22/118 [00:00<00:02, 38.45it/s]

Tensor(array([[-4.51824   , -2.473875  , -2.8938527 , ..., -1.8673961 ,
        -1.8965104 , -1.6940472 ],
       [-0.18095875, -6.4902    , -4.8103356 , ..., -3.628929  ,
        -4.3251724 , -4.0139055 ],
       [-4.406041  , -2.7666106 , -1.7185233 , ..., -3.4253623 ,
        -3.3424594 , -3.1808982 ],
       ...,
       [-3.7275991 , -2.0909355 , -2.6687071 , ..., -2.8949208 ,
        -0.92317474, -2.8861754 ],
       [-4.356218  , -6.789473  , -4.957522  , ..., -1.4676194 ,
        -3.7161663 , -0.91826034],
       [-3.569949  , -2.9744766 , -3.117266  , ..., -1.0617676 ,
        -2.6373627 , -1.7375678 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.010:  24%|██▎       | 28/118 [00:00<00:02, 42.89it/s]

Tensor(array([[-5.2882667 , -0.46672654, -3.236062  , ..., -3.6488914 ,
        -2.0796812 , -3.76762   ],
       [-2.0430253 , -3.1494062 , -1.0585784 , ..., -2.7807002 ,
        -3.0894969 , -3.3207693 ],
       [-2.218663  , -2.614206  , -1.5431967 , ..., -2.915748  ,
        -1.9499111 , -3.3978207 ],
       ...,
       [-4.1664214 , -2.4937859 , -0.62645745, ..., -4.2781925 ,
        -3.2567236 , -3.533776  ],
       [-4.3254614 , -3.6900225 , -1.9780188 , ..., -4.2845564 ,
        -3.3593946 , -3.505684  ],
       [-2.7809167 , -4.35639   , -3.271898  , ..., -4.281629  ,
        -2.0512762 , -4.320358  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.984:  24%|██▎       | 28/118 [00:00<00:02, 42.89it/s]

Tensor(array([[-4.6607084 , -1.2194396 , -2.2843657 , ..., -3.1738677 ,
        -1.5157167 , -2.8127708 ],
       [-4.1235976 , -2.232186  , -1.7688837 , ..., -3.1165679 ,
        -2.6155102 , -2.4998925 ],
       [-3.3562422 , -2.2551    , -1.8116511 , ..., -3.679027  ,
        -1.8291637 , -3.4392257 ],
       ...,
       [-5.379054  , -3.8314037 , -3.570809  , ..., -1.5138346 ,
        -3.0730698 , -1.1221193 ],
       [-4.858163  , -0.5949333 , -3.160525  , ..., -3.3366632 ,
        -2.0005548 , -3.5298984 ],
       [-5.5036635 , -4.4632015 , -3.9968762 , ..., -0.90467334,
        -3.4739153 , -1.2268946 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 1.004:  28%|██▊       | 33/118 [00:01<00:02, 34.95it/s]

Tensor(array([[-5.320982  , -0.37960005, -3.0454829 , ..., -3.9994936 ,
        -2.4032447 , -4.3085637 ],
       [-4.1478004 , -3.1161141 , -3.1206934 , ..., -3.5504842 ,
        -0.6458676 , -2.7729576 ],
       [-3.8691382 , -3.039248  , -2.9201846 , ..., -1.9499106 ,
        -2.3082843 , -1.4217694 ],
       ...,
       [-3.918067  , -2.1841102 , -1.1220639 , ..., -3.416765  ,
        -2.2476597 , -3.1575532 ],
       [-2.291126  , -4.5459223 , -2.0508142 , ..., -2.376198  ,
        -3.527382  , -2.3635826 ],
       [-1.5658686 , -6.4699683 , -4.5086355 , ..., -0.50855184,
        -4.4462056 , -2.7241902 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.979:  33%|███▎      | 39/118 [00:01<00:02, 39.38it/s]

Tensor(array([[-3.438245  , -4.5563526 , -1.2734625 , ..., -2.7709947 ,
        -3.1805878 , -1.7381486 ],
       [-4.2596116 , -5.58304   , -2.135046  , ..., -5.2730484 ,
        -5.07626   , -4.4101996 ],
       [-4.9613285 , -0.6792519 , -2.0732965 , ..., -3.539424  ,
        -2.0716689 , -3.3851097 ],
       ...,
       [-3.272425  , -2.2241235 , -1.212667  , ..., -2.607577  ,
        -2.5073318 , -2.935998  ],
       [-5.3780036 , -0.32664967, -3.062801  , ..., -4.111044  ,
        -2.5479007 , -4.421505  ],
       [-3.8302257 , -5.570281  , -3.150683  , ..., -1.9472253 ,
        -3.7955024 , -1.0094483 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.979:  37%|███▋      | 44/118 [00:01<00:01, 37.87it/s]

Tensor(array([[-2.0258162, -4.3039603, -1.8601084, ..., -2.1792202, -3.3190472,
        -1.900324 ],
       [-4.2049274, -3.111203 , -0.6279185, ..., -4.49068  , -3.2394304,
        -3.8158865],
       [-3.3637486, -2.7111502, -3.1518018, ..., -2.0065203, -1.8020114,
        -2.147239 ],
       ...,
       [-0.5455613, -4.399634 , -2.1565638, ..., -2.7836688, -4.0290403,
        -3.9436157],
       [-4.733926 , -4.710164 , -2.0101483, ..., -5.831579 , -4.712881 ,
        -4.996951 ],
       [-5.7203636, -0.2795086, -2.966999 , ..., -4.061371 , -2.7160177,
        -4.556284 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 0.890:  41%|████      | 48/118 [00:01<00:01, 37.03it/s]

Tensor(array([[-3.6123288, -4.4832373, -3.570375 , ..., -0.8874881, -3.5601907,
        -1.3173522],
       [-0.8345463, -5.493716 , -2.6304374, ..., -3.786275 , -2.9446354,
        -4.2532167],
       [-3.4079847, -3.2206204, -4.3373003, ..., -1.308268 , -1.9141343,
        -2.0221634],
       ...,
       [-4.2549314, -1.3980958, -2.1529167, ..., -2.0729306, -1.8003678,
        -2.5193682],
       [-4.7378545, -3.4810684, -3.632469 , ..., -1.9132787, -2.12463  ,
        -1.3053426],
       [-2.6369085, -5.401043 , -1.829196 , ..., -5.8506136, -4.060412 ,
        -4.647091 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 1.004:  41%|████      | 48/118 [00:01<00:01, 37.03it/s]

Tensor(array([[-5.8457556, -5.8812103, -2.8996081, ..., -3.1481001, -5.056956 ,
        -1.5841002],
       [-1.3768271, -3.7913036, -2.0941074, ..., -4.3196855, -2.662718 ,
        -3.893301 ],
       [-3.5364056, -2.9596043, -2.4811893, ..., -5.340179 , -2.446733 ,
        -5.5546474],
       ...,
       [-4.317749 , -1.9881532, -2.1365352, ..., -4.1469665, -2.1778343,
        -3.5085175],
       [-3.8084137, -1.2015426, -1.2605549, ..., -4.6846685, -3.4027586,
        -5.057847 ],
       [-3.035293 , -2.609547 , -3.129882 , ..., -2.934917 , -2.0055676,
        -3.684861 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 0.884:  44%|████▍     | 52/118 [00:01<00:02, 28.60it/s]

Tensor(array([[-4.620876 , -1.5778366, -2.9814665, ..., -1.3112686, -2.2902985,
        -2.0224838],
       [-3.5381632, -5.038414 , -2.2455885, ..., -5.437585 , -4.4556956,
        -4.555897 ],
       [-2.1026518, -4.606325 , -4.447869 , ..., -1.8475586, -1.9454532,
        -2.233468 ],
       ...,
       [-4.3296523, -7.519948 , -3.4939702, ..., -2.975372 , -4.718328 ,
        -1.6265485],
       [-3.4555056, -3.5786383, -2.6644573, ..., -4.5222306, -1.090491 ,
        -3.7101767],
       [-5.504795 , -0.4237132, -3.058217 , ..., -3.8896182, -2.2413988,
        -4.0008016]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 0.881:  47%|████▋     | 56/118 [00:01<00:02, 27.63it/s]

Tensor(array([[-0.3428614 , -7.076581  , -3.4029174 , ..., -3.5100636 ,
        -3.5135612 , -4.0961947 ],
       [-4.7018538 , -4.7089767 , -2.1090672 , ..., -6.617507  ,
        -4.138936  , -5.486594  ],
       [-5.923719  , -0.24509025, -3.369028  , ..., -4.0101743 ,
        -2.8770006 , -4.560305  ],
       ...,
       [-4.6387334 , -7.9219546 , -5.419929  , ..., -1.8833171 ,
        -3.2509434 , -0.59597206],
       [-3.7363315 , -5.007121  , -4.456307  , ..., -0.48485065,
        -3.4188783 , -1.9442033 ],
       [-4.4759107 , -5.9131017 , -4.369913  , ..., -2.0044694 ,
        -3.0769806 , -0.7520089 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.906:  50%|█████     | 59/118 [00:01<00:02, 25.76it/s]

Tensor(array([[-5.259141  , -0.41185188, -2.3203018 , ..., -4.2561836 ,
        -2.5513215 , -4.3344736 ],
       [-3.5861135 , -4.595511  , -2.200788  , ..., -4.06103   ,
        -1.356302  , -3.3591626 ],
       [-4.384719  , -3.8029575 , -2.5680008 , ..., -2.5187798 ,
        -2.807691  , -1.64851   ],
       ...,
       [-3.2826312 , -4.5599103 , -3.1349387 , ..., -5.441438  ,
        -2.2831073 , -5.238351  ],
       [-4.6399217 , -4.5228004 , -2.715085  , ..., -2.229471  ,
        -3.6207552 , -1.3164592 ],
       [-0.0539875 , -9.284383  , -4.5895443 , ..., -6.0422335 ,
        -5.3259115 , -5.8042593 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.874:  58%|█████▊    | 68/118 [00:02<00:01, 32.75it/s]

Tensor(array([[-4.989224  , -0.49121404, -3.0072315 , ..., -2.7614827 ,
        -2.6248498 , -3.349996  ],
       [-2.0630686 , -5.5593233 , -4.121846  , ..., -0.6200037 ,
        -3.9778023 , -2.457153  ],
       [-5.231188  , -1.4223882 , -1.9092488 , ..., -4.6575356 ,
        -2.7987883 , -4.147027  ],
       ...,
       [-3.4219737 , -4.103824  , -2.7751482 , ..., -5.4036674 ,
        -2.6854954 , -5.384729  ],
       [-3.7983873 , -4.7608285 , -4.1613097 , ..., -4.675229  ,
        -1.9224062 , -5.067484  ],
       [-4.0517178 , -2.2803278 , -3.4117513 , ..., -1.461722  ,
        -2.1407626 , -1.7878929 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.810:  68%|██████▊   | 80/118 [00:02<00:00, 38.76it/s]

Tensor(array([[-4.203536  , -3.4940593 , -0.3163917 , ..., -4.453153  ,
        -3.5485213 , -4.4753237 ],
       [-4.128427  , -6.302945  , -5.274552  , ..., -0.8934438 ,
        -3.622414  , -0.92423487],
       [-4.69211   , -5.132433  , -3.3620267 , ..., -1.6647902 ,
        -2.763823  , -0.8626349 ],
       ...,
       [-5.5488324 , -2.4894345 , -0.6298764 , ..., -6.013768  ,
        -4.0549664 , -5.7370667 ],
       [-6.270397  , -7.164543  , -4.9309278 , ..., -2.8830478 ,
        -4.5824757 , -1.1411505 ],
       [-6.2373857 , -4.201743  , -1.8697298 , ..., -5.2163706 ,
        -4.158849  , -3.9980292 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.787:  74%|███████▎  | 87/118 [00:02<00:00, 42.50it/s]

Tensor(array([[-2.7530072 , -4.79311   , -2.1475291 , ..., -5.871726  ,
        -3.7353816 , -5.162513  ],
       [-3.6609735 , -3.9796402 , -4.98092   , ..., -1.3596028 ,
        -2.5320096 , -1.6071134 ],
       [-5.17943   , -1.7559311 , -3.3337286 , ..., -3.9782825 ,
        -0.64879704, -3.5606034 ],
       ...,
       [-4.278644  , -6.1093483 , -0.22970533, ..., -5.768489  ,
        -5.312022  , -4.897302  ],
       [-0.07447815, -8.703431  , -5.3755035 , ..., -5.2143164 ,
        -5.1412067 , -5.566089  ],
       [-5.353176  , -7.955125  , -6.641905  , ..., -0.14161682,
        -5.3419724 , -2.4401677 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.796:  84%|████████▍ | 99/118 [00:02<00:00, 39.20it/s]

Tensor(array([[-1.3270438 , -5.585017  , -2.9790475 , ..., -2.16882   ,
        -2.6174254 , -2.4624674 ],
       [-3.90336   , -5.161593  , -3.727059  , ..., -6.3361216 ,
        -3.0102348 , -6.582789  ],
       [-0.24284458, -6.237163  , -4.7852526 , ..., -4.09046   ,
        -4.1461363 , -4.7454476 ],
       ...,
       [-6.06089   , -7.6194468 , -5.8745394 , ..., -2.4370492 ,
        -3.597149  , -0.46527743],
       [-6.3199153 , -0.23365283, -3.9732478 , ..., -4.1102796 ,
        -2.6479597 , -4.5736094 ],
       [-4.0066967 , -5.0685563 , -2.1690104 , ..., -3.418629  ,
        -1.7144034 , -1.4027739 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.742:  89%|████████▉ | 105/118 [00:03<00:00, 42.67it/s]

Tensor(array([[-4.320874  , -4.096941  , -4.7944746 , ..., -0.5992081 ,
        -1.6849875 , -2.619483  ],
       [-7.1447077 , -7.9186816 , -5.241009  , ..., -1.8011806 ,
        -5.0221004 , -1.1659336 ],
       [-2.6810875 , -6.1463547 , -3.351257  , ..., -2.754031  ,
        -1.6641779 , -1.40589   ],
       ...,
       [-3.9560404 , -5.7540274 , -4.4078465 , ..., -1.4472208 ,
        -3.0074785 , -0.80335283],
       [-3.674017  , -5.087943  , -3.2575192 , ..., -5.9689765 ,
        -3.2177763 , -6.524892  ],
       [-2.2605202 , -3.4222882 , -1.226243  , ..., -5.3859987 ,
        -3.7142682 , -6.221308  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.686:  97%|█████████▋| 115/118 [00:03<00:00, 42.36it/s]

Tensor(array([[-3.7634394 , -3.5895097 , -0.36458588, ..., -4.3739643 ,
        -3.8473258 , -5.2239695 ],
       [-4.573276  , -3.3980918 , -2.359505  , ..., -4.578832  ,
        -3.0665302 , -3.2653337 ],
       [-3.1910954 , -5.941681  , -3.0668526 , ..., -2.2398744 ,
        -2.1709428 , -3.0607393 ],
       ...,
       [-5.9586387 , -7.675467  , -4.5072236 , ..., -2.0690594 ,
        -4.4028015 , -0.5146773 ],
       [-6.581989  , -8.94303   , -6.522316  , ..., -1.3673563 ,
        -5.292364  , -0.7115631 ],
       [-5.208584  , -3.570916  , -3.0086465 , ..., -2.114963  ,
        -2.7274058 , -1.6755012 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.708: 100%|██████████| 118/118 [00:03<00:00, 35.72it/s]


Tensor(array([[-4.9088840e+00, -3.9932578e+00, -3.6408556e+00, -1.6843750e+00,
        -5.0433559e+00, -1.9938835e+00, -5.1723123e+00, -5.1879582e+00,
        -5.1931596e-01, -4.2984471e+00],
       [-4.3347893e+00, -5.1976533e+00, -3.5682704e+00, -3.9244056e+00,
        -2.1956937e+00, -2.5914097e+00, -5.1277390e+00, -1.7413317e+00,
        -1.7198876e+00, -9.4956660e-01],
       [-5.6056442e+00, -4.5663519e+00, -3.4871373e+00, -6.1877947e+00,
        -1.1252701e+00, -3.6829357e+00, -3.5714688e+00, -2.0576105e+00,
        -3.0991092e+00, -9.0979195e-01],
       [-4.0766582e+00, -5.1395607e+00, -2.7372880e+00, -3.3835530e-01,
        -6.8664427e+00, -2.7887061e+00, -6.2302971e+00, -5.7433100e+00,
        -2.0551715e+00, -5.5991259e+00],
       [-4.4415016e+00, -4.5180874e+00, -2.7755108e+00, -4.8274813e+00,
        -9.0225768e-01, -3.3842683e+00, -2.7608969e+00, -1.7839435e+00,
        -4.0254722e+00, -1.5215740e+00],
       [-3.3318748e+00, -2.3446851e+00, -2.0963130e+00, -1.0303681e+

Train set: Avg loss: 0.734:   0%|          | 0/118 [00:00<?, ?it/s]

Tensor(array([[-6.9170666 , -7.9764004 , -5.1716747 , ..., -2.225451  ,
        -4.601965  , -1.3937659 ],
       [-6.0462627 , -0.2045865 , -3.46354   , ..., -4.361751  ,
        -3.009234  , -5.390413  ],
       [-3.7138977 , -3.8050401 , -0.80492306, ..., -3.5037675 ,
        -3.0687346 , -3.1185498 ],
       ...,
       [-4.8376994 , -5.7563286 , -5.509284  , ..., -1.2894344 ,
        -1.8929417 , -1.0555842 ],
       [-4.2224956 , -4.077037  , -2.404843  , ..., -6.8148785 ,
        -2.9275289 , -7.000126  ],
       [-3.5131702 , -4.77991   , -3.703236  , ..., -2.718724  ,
        -1.4836555 , -1.8615123 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.737:   2%|▏         | 2/118 [00:00<00:07, 15.58it/s]

Tensor(array([[-5.011526  , -3.1373174 , -3.1468754 , ..., -0.5686586 ,
        -2.2923493 , -2.094079  ],
       [-5.033926  , -3.3883157 , -0.21030903, ..., -4.5582733 ,
        -3.9217615 , -5.624222  ],
       [-4.914484  , -1.971025  , -3.6166856 , ..., -1.1822593 ,
        -1.9958308 , -1.9089713 ],
       ...,
       [-4.994257  , -6.762512  , -4.5490656 , ..., -1.9758376 ,
        -2.6274693 , -0.8780184 ],
       [-4.98588   , -4.128069  , -4.0011945 , ..., -2.033176  ,
        -0.5919821 , -2.2979162 ],
       [-2.7205749 , -6.824281  , -6.7496696 , ..., -4.019126  ,
        -2.5335193 , -4.5019393 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.761:   7%|▋         | 8/118 [00:00<00:03, 35.72it/s]

Tensor(array([[-0.07722425, -9.904116  , -5.629866  , ..., -5.763419  ,
        -5.207774  , -5.8940325 ],
       [-5.974744  , -2.6494231 , -3.0449793 , ..., -5.6643763 ,
        -0.32639098, -4.2695417 ],
       [-6.345461  , -0.19708085, -3.5074778 , ..., -4.296147  ,
        -3.1044354 , -5.0413485 ],
       ...,
       [-1.0534081 , -7.052661  , -3.4874399 , ..., -2.8777027 ,
        -2.033043  , -2.4958048 ],
       [-3.640664  , -6.916149  , -0.44634342, ..., -7.0269756 ,
        -6.003587  , -6.2796288 ],
       [-6.7031345 , -0.25802994, -2.5294023 , ..., -4.5337086 ,
        -2.6021843 , -4.9126587 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.683:  10%|█         | 12/118 [00:00<00:03, 29.86it/s]

Tensor(array([[ -4.414774  ,  -8.360739  ,  -1.7193584 , ...,  -6.719219  ,
         -4.1162815 ,  -3.948312  ],
       [ -0.01443577, -11.053712  ,  -6.3780813 , ...,  -5.9301434 ,
         -7.5074034 ,  -7.001383  ],
       [ -5.0604677 ,  -5.710381  ,  -0.43354607, ...,  -4.0893645 ,
         -4.415464  ,  -2.9035752 ],
       ...,
       [ -6.4164963 ,  -0.63178825,  -1.8637161 , ...,  -4.6303396 ,
         -1.7125863 ,  -3.9667964 ],
       [ -3.3292432 ,  -5.8332505 ,  -2.9193664 , ...,  -6.6420093 ,
         -3.7631748 ,  -5.3023667 ],
       [ -1.6364896 ,  -6.7423177 ,  -0.3800521 , ...,  -6.478917  ,
         -5.294908  ,  -7.1722555 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.641:  14%|█▎        | 16/118 [00:00<00:03, 32.88it/s]

Tensor(array([[-6.6011214 , -7.9659605 , -3.1602454 , ..., -6.5767126 ,
        -7.3155055 , -5.2840686 ],
       [-5.3272786 , -7.305706  , -3.0459237 , ..., -6.8583274 ,
        -6.424337  , -5.460357  ],
       [-5.052193  , -8.773134  , -4.570035  , ..., -3.5739186 ,
        -5.245138  , -1.4791572 ],
       ...,
       [-5.5903215 , -1.1539085 , -0.99945354, ..., -5.1963167 ,
        -2.5526247 , -5.2182746 ],
       [-2.2769196 , -6.9446774 , -3.0239677 , ..., -6.8701057 ,
        -3.2572043 , -6.889597  ],
       [-6.0495205 , -5.5198293 , -4.266011  , ..., -2.350321  ,
        -3.7121408 , -1.3658504 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.707:  14%|█▎        | 16/118 [00:00<00:03, 32.88it/s]

Tensor(array([[-5.86428   , -0.8299084 , -2.7431002 , ..., -1.9511449 ,
        -2.4646833 , -2.1171436 ],
       [-5.1126657 , -4.1958685 , -5.1087737 , ..., -0.5224674 ,
        -3.7433443 , -1.3678463 ],
       [-6.6329403 , -8.537479  , -6.2073646 , ..., -2.8072402 ,
        -4.236554  , -1.1071239 ],
       ...,
       [-6.598609  , -0.825727  , -3.0060031 , ..., -4.249159  ,
        -1.344017  , -3.675797  ],
       [-2.6106768 , -3.0854435 , -1.5471761 , ..., -2.6423056 ,
        -3.1952715 , -3.392672  ],
       [-5.2923107 , -4.8606167 , -4.12296   , ..., -3.6744761 ,
        -0.28333235, -2.5779002 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.700:  17%|█▋        | 20/118 [00:00<00:03, 29.98it/s]

Tensor(array([[-6.045681  , -0.35591245, -3.0697246 , ..., -5.908642  ,
        -2.485085  , -5.8125753 ],
       [-5.1259146 , -1.0609505 , -2.2972956 , ..., -4.830302  ,
        -1.1138461 , -4.257304  ],
       [-5.1359663 , -2.3652701 , -4.3854375 , ..., -1.8971367 ,
        -1.8792754 , -1.1624963 ],
       ...,
       [-4.947     , -6.0394497 , -5.7144213 , ..., -0.6691475 ,
        -4.0665326 , -1.106288  ],
       [-3.625485  , -4.4024844 , -4.528129  , ..., -3.7941024 ,
        -0.8778024 , -3.4984517 ],
       [-1.2376494 , -6.7804527 , -2.9356108 , ..., -3.1991065 ,
        -2.733412  , -3.91307   ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.680:  20%|██        | 24/118 [00:00<00:03, 31.32it/s]

Tensor(array([[ -0.04295731, -12.088636  ,  -5.432996  , ...,  -6.3086424 ,
         -5.7850885 ,  -6.26523   ],
       [ -0.37240052,  -6.803814  ,  -3.466735  , ...,  -4.9141083 ,
         -3.385458  ,  -5.182271  ],
       [ -3.7806895 ,  -9.498848  ,  -3.3951712 , ...,  -1.8226767 ,
         -5.584216  ,  -0.9630742 ],
       ...,
       [ -6.0680842 ,  -5.751613  ,  -5.5161314 , ...,  -0.20290089,
         -4.394828  ,  -2.2182505 ],
       [ -5.441062  ,  -6.4487987 ,  -4.1614623 , ...,  -2.4781003 ,
         -3.5462592 ,  -0.614171  ],
       [ -4.218356  ,  -4.1644306 ,  -0.30679417, ...,  -3.9321122 ,
         -3.1750314 ,  -3.166432  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.710:  24%|██▎       | 28/118 [00:00<00:03, 28.93it/s]

Tensor(array([[-3.3492453, -7.866652 , -4.692112 , ..., -5.4057174, -3.3734925,
        -6.841529 ],
       [-3.199812 , -4.1536655, -2.1298153, ..., -3.0297515, -1.7076781,
        -3.0639632],
       [-3.8665109, -4.786989 , -3.3034463, ..., -6.067705 , -1.3325249,
        -3.8104615],
       ...,
       [-4.33752  , -2.890816 , -2.4986556, ..., -5.6251574, -0.7991655,
        -4.2622313],
       [-1.955779 , -4.654697 , -2.546882 , ..., -5.333705 , -2.1011043,
        -5.689881 ],
       [-5.274743 , -1.6312149, -4.4352894, ..., -4.3615046, -0.9885938,
        -4.8769674]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 0.660:  24%|██▎       | 28/118 [00:01<00:03, 28.93it/s]

Tensor(array([[-3.3982399 , -9.410258  , -7.275608  , ..., -0.11298323,
        -5.9761424 , -3.3755302 ],
       [-5.529643  , -3.9464374 , -2.1338599 , ..., -6.9743814 ,
        -4.199121  , -6.62211   ],
       [-2.3402376 , -4.771491  , -3.2110133 , ..., -4.426289  ,
        -3.3465662 , -6.005461  ],
       ...,
       [-4.1478596 , -2.3414521 , -1.7908716 , ..., -5.4246016 ,
        -1.76272   , -5.364062  ],
       [-4.095356  , -3.5798154 , -2.0434313 , ..., -4.7167635 ,
        -3.3139372 , -6.1664996 ],
       [-8.317374  , -8.613347  , -6.1141224 , ..., -3.7307825 ,
        -5.8959446 , -1.8228915 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.666:  27%|██▋       | 32/118 [00:01<00:02, 30.52it/s]

Tensor(array([[ -7.606336  ,  -9.655768  ,  -5.972499  , ...,  -4.03768   ,
         -5.6794906 ,  -2.0782294 ],
       [ -2.1852765 ,  -8.413408  ,  -2.239633  , ...,  -3.7408917 ,
         -5.193411  ,  -3.392324  ],
       [ -6.0700054 ,  -0.6281717 ,  -2.9784584 , ...,  -4.695236  ,
         -1.2797229 ,  -4.83686   ],
       ...,
       [ -0.03800535, -10.572154  ,  -6.593682  , ...,  -6.894862  ,
         -6.4931345 ,  -7.380997  ],
       [ -7.414917  ,  -0.08969355,  -3.6837325 , ...,  -5.633858  ,
         -4.101844  ,  -6.5892076 ],
       [ -0.51108503, -10.4133005 ,  -6.296123  , ...,  -6.0512214 ,
         -2.697807  ,  -5.3108206 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.605:  40%|███▉      | 47/118 [00:01<00:01, 37.47it/s]

Tensor(array([[ -8.539593  ,  -0.08479118,  -3.5176454 , ...,  -6.4997025 ,
         -3.8782053 ,  -6.816507  ],
       [ -6.8745203 ,  -0.14163446,  -3.797028  , ...,  -4.6560807 ,
         -3.2899714 ,  -5.3562875 ],
       [ -0.04868221, -10.72512   ,  -5.3027163 , ...,  -7.962989  ,
         -5.822192  ,  -8.133045  ],
       ...,
       [ -4.8612375 ,  -6.7242594 ,  -2.761711  , ...,  -7.3693295 ,
         -5.6897864 ,  -5.8807526 ],
       [ -2.8261352 ,  -6.875721  ,  -6.7657585 , ...,  -3.7821622 ,
         -2.0921628 ,  -4.000117  ],
       [ -6.1922064 ,  -8.267229  ,  -4.35011   , ...,  -5.735079  ,
         -2.441811  ,  -1.1697781 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.680:  43%|████▎     | 51/118 [00:01<00:02, 33.43it/s]

Tensor(array([[-6.753042 , -6.7074695, -3.84117  , ..., -8.697667 , -5.271836 ,
        -6.3506026],
       [-2.7830567, -5.698336 , -4.844146 , ..., -3.8755999, -2.8008637,
        -3.400423 ],
       [-3.5302854, -2.0731494, -4.5527625, ..., -3.3713477, -2.118371 ,
        -4.8163304],
       ...,
       [-4.582694 , -8.414324 , -2.0539207, ..., -8.546383 , -6.2110596,
        -6.0574183],
       [-6.79792  , -0.7363391, -3.3428175, ..., -4.5452585, -1.3407593,
        -3.8454525],
       [-5.3281727, -9.2141905, -3.310097 , ..., -7.9807715, -7.8003893,
        -6.5638967]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65, 

Train set: Avg loss: 0.624:  53%|█████▎    | 63/118 [00:01<00:01, 43.94it/s]

Tensor(array([[-0.51770663, -6.989003  , -5.4040728 , ..., -2.439084  ,
        -4.0720506 , -4.4107857 ],
       [-4.54331   , -7.950429  , -6.244297  , ..., -0.39366913,
        -4.4319024 , -1.3531208 ],
       [-5.4363394 , -5.6716547 , -3.5494134 , ..., -4.563806  ,
        -0.32299876, -2.387473  ],
       ...,
       [-5.5186696 , -6.255926  , -5.730344  , ..., -0.15037346,
        -5.48594   , -2.41381   ],
       [-3.9275432 , -9.465769  , -7.637615  , ..., -0.50061274,
        -5.625482  , -1.3480916 ],
       [-5.703343  , -5.1965475 , -0.15860343, ..., -6.0376406 ,
        -2.9135904 , -3.89172   ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.533:  62%|██████▏   | 73/118 [00:02<00:01, 39.13it/s]

Tensor(array([[ -5.961698  ,  -5.142013  ,  -5.6653543 , ...,  -0.33103895,
         -4.3524346 ,  -1.6110401 ],
       [ -7.730416  ,  -7.0343385 ,  -2.9408693 , ...,  -7.970463  ,
         -4.3413916 ,  -4.712897  ],
       [ -2.348289  ,  -6.873563  ,  -5.000752  , ...,  -0.25086594,
         -5.5974846 ,  -3.5298517 ],
       ...,
       [ -2.2772965 ,  -6.550171  ,  -0.61583734, ...,  -5.000671  ,
         -3.770987  ,  -4.6358066 ],
       [ -7.5502357 ,  -9.791108  ,  -3.3548129 , ...,  -8.224701  ,
         -5.667251  ,  -4.833455  ],
       [ -8.334479  , -10.028528  ,  -3.6701307 , ..., -10.2524395 ,
         -8.070004  ,  -7.6037183 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.587:  76%|███████▋  | 90/118 [00:02<00:00, 43.62it/s]

Tensor(array([[-6.603922  , -5.157298  , -5.889664  , ..., -2.0888786 ,
        -1.4669614 , -0.7190552 ],
       [-5.9372854 , -1.6008708 , -2.4492517 , ..., -5.419811  ,
        -0.6337714 , -4.7069674 ],
       [-5.3285136 , -6.1344156 , -0.21830583, ..., -5.4676166 ,
        -2.885204  , -3.1850157 ],
       ...,
       [-8.440032  , -0.13152361, -4.499471  , ..., -5.638616  ,
        -2.6884804 , -6.0238085 ],
       [-5.4971232 , -4.5950313 , -1.0659387 , ..., -6.964465  ,
        -1.1391997 , -5.205436  ],
       [-7.299489  , -4.3599324 , -4.076544  , ..., -3.881868  ,
        -3.4738965 , -1.9093142 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.591:  81%|████████  | 95/118 [00:02<00:00, 40.62it/s]

Tensor(array([[-5.5839796 , -0.92427635, -3.3221655 , ..., -4.596494  ,
        -1.1603172 , -4.5955257 ],
       [-4.588188  , -6.1189184 , -3.8182244 , ..., -3.903793  ,
        -1.2358694 , -2.6091595 ],
       [-5.32812   , -5.9333916 , -4.277334  , ..., -2.0214467 ,
        -0.9020839 , -1.6534897 ],
       ...,
       [-4.1145535 , -6.0915527 , -5.5677385 , ..., -5.3354273 ,
        -2.5317514 , -5.9127245 ],
       [-6.285183  , -1.6663969 , -0.5496824 , ..., -6.5414267 ,
        -1.9780166 , -5.676544  ],
       [-3.908868  , -9.2992525 , -7.763333  , ..., -0.1960144 ,
        -6.40738   , -2.2688508 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.574:  89%|████████▉ | 105/118 [00:02<00:00, 36.02it/s]

Tensor(array([[ -5.318746  ,  -7.063119  ,  -6.3033285 , ...,  -0.91640663,
         -3.586821  ,  -0.7882583 ],
       [ -4.263942  ,  -8.867464  ,  -6.705754  , ...,  -1.7143955 ,
         -4.3653417 ,  -0.5403607 ],
       [ -6.2439475 ,  -7.016243  ,  -5.576329  , ...,  -6.3357944 ,
         -0.3045783 ,  -4.214018  ],
       ...,
       [ -2.6592016 ,  -5.583261  ,  -4.8386316 , ...,  -0.7326131 ,
         -3.9323626 ,  -2.4895344 ],
       [ -5.3393106 , -10.505097  ,  -8.058532  , ...,  -0.04787445,
         -8.126974  ,  -3.4173143 ],
       [ -8.6451435 ,  -6.9493384 ,  -0.03888655, ..., -10.131264  ,
         -6.133037  ,  -8.155402  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.536:  92%|█████████▏| 109/118 [00:03<00:00, 32.72it/s]

Tensor(array([[ -6.254934  ,  -5.9928894 ,  -0.03129625, ...,  -8.102425  ,
         -4.651877  ,  -7.0363026 ],
       [ -3.9854655 ,  -8.773215  ,  -6.059326  , ...,  -3.5006087 ,
         -4.115113  ,  -1.5896297 ],
       [ -6.0058064 ,  -6.6917214 ,  -6.1862316 , ...,  -0.16651535,
         -4.9114184 ,  -2.1316626 ],
       ...,
       [ -7.8675976 ,  -6.212064  ,  -0.05940342, ...,  -9.63975   ,
         -7.319923  ,  -8.354972  ],
       [ -8.525058  , -11.406361  ,  -5.4778852 , ...,  -6.15244   ,
         -6.7373233 ,  -3.2843618 ],
       [ -6.6906633 ,  -4.735227  ,  -0.0496974 , ...,  -6.998575  ,
         -6.0678163 ,  -7.6906104 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.495: 100%|██████████| 118/118 [00:03<00:00, 35.49it/s]


Tensor(array([[ -7.7672243 ,  -9.027751  ,  -7.6147223 , ...,  -3.2176592 ,
         -4.0527725 ,  -1.2322006 ],
       [ -0.07153082, -10.971799  ,  -5.396838  , ...,  -4.182137  ,
         -5.688326  ,  -6.1516137 ],
       [ -8.071026  ,  -9.35034   ,  -6.050942  , ...,  -4.0958714 ,
         -6.485668  ,  -2.5543563 ],
       ...,
       [ -3.3031707 ,  -5.2105165 ,  -2.7663183 , ...,  -7.4996157 ,
         -3.7788293 ,  -6.099396  ],
       [ -2.2150068 ,  -4.083897  ,  -3.600099  , ...,  -1.7991999 ,
         -3.453118  ,  -3.187276  ],
       [ -5.174509  ,  -3.0968823 ,  -2.1901171 , ...,  -6.491061  ,
         -0.57133317,  -4.7496395 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.588:   0%|          | 0/118 [00:00<?, ?it/s]

Tensor(array([[ -8.713245  ,  -9.5415945 ,  -6.928442  , ...,  -4.8493295 ,
         -4.8090897 ,  -0.9524808 ],
       [ -0.01681471, -12.482321  ,  -5.542584  , ...,  -5.3673496 ,
         -8.933785  ,  -6.639588  ],
       [ -6.579398  ,  -7.6656976 ,  -4.332069  , ...,  -0.83863664,
         -5.177892  ,  -0.86878896],
       ...,
       [ -1.8856382 ,  -8.049534  ,  -3.7164087 , ...,  -6.1875973 ,
         -5.85909   ,  -6.153358  ],
       [ -6.265541  ,  -7.9796867 ,  -4.9925337 , ...,  -4.314353  ,
         -3.616584  ,  -1.3792288 ],
       [ -0.22321272,  -7.8146744 ,  -4.3716717 , ...,  -5.1581993 ,
         -6.1268816 ,  -6.4832    ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.505:   2%|▏         | 2/118 [00:00<00:06, 17.53it/s]

Tensor(array([[-7.357764  , -0.12579632, -4.025055  , ..., -3.9951136 ,
        -3.9211638 , -4.9407024 ],
       [-5.153817  , -4.8693113 , -5.6714115 , ..., -1.6726607 ,
        -3.2457082 , -0.98472285],
       [-7.0874414 , -9.621571  , -4.4987283 , ..., -3.9105272 ,
        -6.209318  , -1.946356  ],
       ...,
       [-6.605731  , -4.612789  , -4.968876  , ..., -2.9549732 ,
        -1.3662899 , -0.90649605],
       [-7.9271574 , -0.15357447, -2.5834162 , ..., -5.3745723 ,
        -3.5078897 , -5.89386   ],
       [-4.7330174 , -3.6823826 , -0.8978877 , ..., -5.872904  ,
        -5.462292  , -5.9929314 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.504:   2%|▏         | 2/118 [00:00<00:06, 17.53it/s]

Tensor(array([[ -9.205053  , -11.285825  ,  -7.975714  , ...,  -6.2142467 ,
         -6.060459  ,  -2.8696764 ],
       [ -7.226981  ,  -5.524188  ,  -3.2355118 , ...,  -8.81943   ,
         -4.663638  ,  -6.9298368 ],
       [ -7.1700287 ,  -7.7859955 ,  -5.2681236 , ...,  -8.227509  ,
         -0.26026917,  -5.82731   ],
       ...,
       [ -8.386849  , -12.098843  ,  -4.3387756 , ..., -12.115995  ,
         -8.69101   ,  -8.973318  ],
       [ -3.0332232 ,  -5.2589645 ,  -4.763423  , ...,  -5.035024  ,
         -2.8596573 ,  -4.7110443 ],
       [ -7.4516735 ,  -8.9706135 ,  -7.506381  , ...,  -0.0502882 ,
         -7.391548  ,  -3.1344168 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.516:  11%|█         | 13/118 [00:00<00:03, 34.49it/s]

Tensor(array([[-1.6369736 , -7.9628305 , -1.734484  , ..., -8.345445  ,
        -4.247659  , -8.007153  ],
       [-2.2407994 , -5.720412  , -4.479504  , ..., -3.2242897 ,
        -1.459138  , -1.3734093 ],
       [-8.424251  , -0.17782307, -3.6089187 , ..., -6.607287  ,
        -2.3345854 , -6.5176287 ],
       ...,
       [-5.9990234 , -5.9326353 , -6.004071  , ..., -2.6731582 ,
        -0.59104395, -1.2899537 ],
       [-2.2477024 , -8.652266  , -6.2121043 , ..., -4.7075987 ,
        -4.001525  , -6.6141915 ],
       [-6.760827  , -3.6728625 , -0.19939232, ..., -4.8175545 ,
        -4.3993745 , -4.5494556 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.557:  14%|█▍        | 17/118 [00:00<00:03, 28.90it/s]

Tensor(array([[ -3.815892  ,  -7.0974927 ,  -3.2621522 , ...,  -7.035719  ,
         -7.19236   ,  -7.2820106 ],
       [ -3.6240022 ,  -7.338462  ,  -4.4545884 , ...,  -3.1083872 ,
         -1.699075  ,  -1.3562975 ],
       [ -2.8366942 ,  -3.0446851 ,  -3.178639  , ...,  -3.2732847 ,
         -3.4746423 ,  -5.247756  ],
       ...,
       [ -0.0198679 , -13.109419  ,  -8.730054  , ...,  -7.0752654 ,
         -8.343762  ,  -8.643741  ],
       [ -0.03666735, -12.578131  ,  -8.136402  , ...,  -8.182655  ,
         -7.6677914 ,  -8.806158  ],
       [ -8.532345  ,  -0.0607233 ,  -4.3527856 , ...,  -7.19158   ,
         -3.9884522 ,  -7.8596163 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.554:  20%|██        | 24/118 [00:00<00:02, 38.76it/s]

Tensor(array([[ -7.165265  ,  -6.145916  ,  -4.5677767 , ...,  -8.08251   ,
         -1.7793639 ,  -7.094624  ],
       [ -3.5105624 ,  -8.596699  ,  -6.6885324 , ...,  -6.5646048 ,
         -2.0335073 ,  -4.5230374 ],
       [ -5.7813606 , -10.552357  ,  -4.535475  , ...,  -3.3767054 ,
         -5.8775225 ,  -1.9721634 ],
       ...,
       [ -5.670783  ,  -5.5747023 ,  -4.784833  , ...,  -1.6764616 ,
         -3.781909  ,  -0.68678355],
       [ -3.9385142 ,  -9.686451  ,  -6.524298  , ...,  -1.908803  ,
         -3.1378403 ,  -0.4186375 ],
       [ -6.08881   ,  -1.7403963 ,  -3.4101198 , ...,  -5.532319  ,
         -0.49019623,  -4.7979946 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.532:  20%|██        | 24/118 [00:00<00:02, 38.76it/s]

Tensor(array([[ -9.4159565 , -11.043467  ,  -8.1766405 , ...,  -4.6417675 ,
         -6.3747406 ,  -1.8149128 ],
       [ -8.065692  , -10.7155485 ,  -8.582424  , ...,  -0.04582787,
         -8.576589  ,  -3.1802182 ],
       [ -6.6949353 ,  -8.401638  ,  -5.680974  , ...,  -0.19154263,
         -5.3854227 ,  -1.940434  ],
       ...,
       [ -7.4360533 ,  -5.4448013 ,  -3.2309856 , ...,  -4.909603  ,
         -3.7526493 ,  -2.9026182 ],
       [ -3.8472977 ,  -6.898607  ,  -2.310719  , ...,  -7.691266  ,
         -4.600248  ,  -8.695776  ],
       [ -8.429071  ,  -0.09853697,  -4.9376755 , ...,  -6.6140585 ,
         -3.0783718 ,  -7.0322065 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.537:  20%|██        | 24/118 [00:00<00:02, 38.76it/s]

Tensor(array([[ -3.8660345 , -10.184732  ,  -6.2681894 , ...,  -4.6975513 ,
         -5.329393  ,  -2.1404457 ],
       [ -4.3844094 ,  -6.806714  ,  -4.5551186 , ...,  -8.564277  ,
         -0.6297121 ,  -5.650176  ],
       [ -5.2196608 ,  -4.9184914 ,  -5.914656  , ...,  -6.458785  ,
         -0.66685367,  -4.4746604 ],
       ...,
       [ -6.0519056 ,  -7.3884983 ,  -5.003675  , ...,  -5.8526025 ,
         -0.22115183,  -4.2387686 ],
       [ -8.2450905 ,  -2.0431583 ,  -1.4043524 , ...,  -7.9583416 ,
         -3.0616355 ,  -6.3694763 ],
       [ -7.089872  ,  -9.513503  ,  -8.638938  , ...,  -0.02491379,
         -7.2014065 ,  -3.9905584 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.592:  29%|██▉       | 34/118 [00:01<00:02, 33.57it/s]

Tensor(array([[ -8.6854515 , -12.82705   ,  -8.110715  , ...,  -5.9832063 ,
         -5.375145  ,  -1.3951597 ],
       [ -8.844959  ,  -0.18829918,  -2.5851362 , ...,  -5.5429907 ,
         -2.7106993 ,  -6.024728  ],
       [ -6.6860056 , -10.99884   ,  -4.107281  , ...,  -5.6646376 ,
         -4.466213  ,  -2.2068372 ],
       ...,
       [ -3.243732  ,  -8.321044  ,  -6.492485  , ...,  -3.2928972 ,
         -4.003598  ,  -1.0246835 ],
       [ -3.0674381 ,  -7.2033787 ,  -0.46115017, ...,  -7.8658323 ,
         -6.784874  ,  -8.228863  ],
       [ -5.898233  , -12.304323  ,  -7.783549  , ...,  -3.403749  ,
         -4.8268557 ,  -0.1520772 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.488:  29%|██▉       | 34/118 [00:01<00:02, 33.57it/s]

Tensor(array([[-4.1546593e+00, -5.1046858e+00, -5.7451410e+00, ...,
        -4.3358898e+00, -2.2980204e+00, -3.8224034e+00],
       [-9.4931355e+00, -9.3569756e-02, -2.9941447e+00, ...,
        -8.4614944e+00, -4.0831919e+00, -8.4933510e+00],
       [-5.8611431e+00, -6.4189863e+00, -3.6288178e+00, ...,
        -1.9962801e+00, -4.8709984e+00, -1.3272154e+00],
       ...,
       [-4.2572904e+00, -7.6003556e+00, -6.4544725e+00, ...,
        -6.0379496e+00, -2.7997451e+00, -5.6362553e+00],
       [-6.6305923e+00, -8.3612680e+00, -6.9187365e+00, ...,
        -3.1059666e+00, -3.4114852e+00, -2.8664827e-01],
       [-1.4352798e-03, -1.6798038e+01, -9.1685829e+00, ...,
        -9.3491011e+00, -9.4285288e+00, -9.7915478e+00]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39, 

Train set: Avg loss: 0.520:  32%|███▏      | 38/118 [00:01<00:02, 33.02it/s]

Tensor(array([[ -7.9383626 , -12.307114  , -11.127214  , ...,  -0.31847286,
         -7.427077  ,  -1.4315042 ],
       [ -3.7740886 ,  -5.9788256 ,  -5.969533  , ...,  -4.4532166 ,
         -2.1612215 ,  -3.0051394 ],
       [ -4.552305  ,  -5.265687  ,  -3.8167634 , ...,  -5.513995  ,
         -3.1666446 ,  -6.653942  ],
       ...,
       [ -5.2243543 , -11.873003  ,  -9.700891  , ...,  -0.15111256,
         -7.061238  ,  -2.2748244 ],
       [ -3.252081  ,  -4.5816946 ,  -3.422635  , ...,  -3.4088848 ,
         -1.5476714 ,  -3.3877141 ],
       [ -0.02186489, -14.19519   ,  -8.137954  , ...,  -9.232138  ,
         -7.899213  ,  -8.64708   ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.505:  36%|███▌      | 42/118 [00:01<00:02, 29.48it/s]

Tensor(array([[ -5.6314707 ,  -2.592771  ,  -2.647323  , ...,  -6.7085557 ,
         -1.8969722 ,  -7.0342655 ],
       [ -8.139145  ,  -0.11227894,  -3.6179786 , ...,  -6.7732816 ,
         -3.1517406 ,  -7.0261583 ],
       [ -6.661315  ,  -8.540424  ,  -4.146763  , ...,  -8.475158  ,
         -6.330221  ,  -6.700385  ],
       ...,
       [ -6.7533517 , -11.47239   ,  -4.3102174 , ..., -10.298978  ,
         -8.785286  ,  -7.8532343 ],
       [ -7.8761086 ,  -5.297795  ,  -0.03461695, ...,  -9.533636  ,
         -5.8831835 ,  -9.671659  ],
       [ -6.0319605 ,  -7.428278  ,  -5.8626847 , ...,  -7.8266144 ,
         -0.14594746,  -4.7329082 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.512:  36%|███▌      | 42/118 [00:01<00:02, 29.48it/s]

Tensor(array([[-8.4666595e+00, -9.7830620e+00, -4.2490778e+00, ...,
        -1.2512463e+01, -5.5657711e+00, -8.3579454e+00],
       [-2.9164357e+00, -5.7978888e+00, -3.2512941e+00, ...,
        -7.3826246e+00, -2.8979254e+00, -6.5011482e+00],
       [-9.0878344e+00, -2.6426888e-01, -2.3822389e+00, ...,
        -5.4519987e+00, -2.2526762e+00, -5.4403672e+00],
       ...,
       [-7.5010529e+00, -8.4910450e+00, -6.1532583e+00, ...,
        -5.3770366e+00, -4.6382127e+00, -2.8878250e+00],
       [-6.7597923e+00, -3.9742379e+00, -3.1677432e+00, ...,
        -6.6908836e+00, -1.5103054e-01, -5.5864005e+00],
       [-2.8944016e-03, -1.7156126e+01, -7.4197412e+00, ...,
        -8.3018646e+00, -8.7021170e+00, -7.8876033e+00]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39, 

Train set: Avg loss: 0.525:  39%|███▉      | 46/118 [00:01<00:02, 26.07it/s]

Tensor(array([[-3.1957998 , -6.864992  , -1.8113747 , ..., -6.6270456 ,
        -2.690847  , -3.1682374 ],
       [-3.202267  , -9.411995  , -7.89871   , ..., -3.7556205 ,
        -3.7187896 , -2.774705  ],
       [-8.810768  , -0.18992901, -3.2480257 , ..., -6.774148  ,
        -2.3057756 , -6.4063935 ],
       ...,
       [-4.1283846 , -1.9641038 , -1.7670605 , ..., -3.5220926 ,
        -3.7597947 , -4.4015694 ],
       [-7.868569  , -0.1351285 , -3.995274  , ..., -3.3964858 ,
        -4.1803837 , -4.801048  ],
       [-4.972215  , -8.640361  , -5.395943  , ..., -6.233799  ,
        -0.19743443, -2.7775173 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.528:  42%|████▏     | 49/118 [00:01<00:02, 25.66it/s]

Tensor(array([[ -6.8201637 ,  -8.465273  ,  -6.1670656 , ...,  -9.262152  ,
         -0.14296961,  -6.1095233 ],
       [ -9.031437  ,  -0.0366087 ,  -4.5473776 , ...,  -6.442141  ,
         -5.023693  ,  -7.294113  ],
       [ -4.1260734 ,  -4.0708222 ,  -3.8753905 , ...,  -2.716938  ,
         -2.1979094 ,  -4.0799603 ],
       ...,
       [ -9.277003  , -13.378405  ,  -9.7717085 , ...,  -5.370104  ,
         -6.9488125 ,  -2.3850274 ],
       [ -6.1013803 ,  -5.334731  ,  -6.1026764 , ...,  -1.0660019 ,
         -2.570765  ,  -1.3797693 ],
       [ -7.4378557 ,  -7.9388824 ,  -7.4444785 , ...,  -2.0327382 ,
         -2.6570225 ,  -0.3991723 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.498:  44%|████▍     | 52/118 [00:01<00:02, 25.79it/s]

Tensor(array([[ -7.900443  ,  -9.893706  ,  -4.764699  , ..., -11.663204  ,
         -7.93684   ,  -9.048399  ],
       [ -8.406994  ,  -6.2101803 ,  -5.196739  , ...,  -7.488053  ,
         -0.05592871,  -5.2210603 ],
       [ -3.0376048 , -11.838264  , -10.100576  , ...,  -0.20412302,
         -6.9209566 ,  -2.6080475 ],
       ...,
       [ -2.8498266 ,  -9.083674  ,  -5.057904  , ...,  -5.543793  ,
         -5.256638  ,  -8.610743  ],
       [ -7.161862  ,  -8.205935  ,  -3.6313646 , ...,  -6.548029  ,
         -6.5207424 ,  -4.89966   ],
       [ -7.784094  ,  -4.5692396 ,  -5.599723  , ...,  -7.7467184 ,
         -0.0707221 ,  -5.185056  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.556:  47%|████▋     | 55/118 [00:01<00:02, 25.45it/s]

Tensor(array([[ -4.737878  ,  -4.7044444 ,  -2.9272103 , ...,  -7.949223  ,
         -1.3556921 ,  -5.6797123 ],
       [ -4.126856  ,  -5.338813  ,  -3.4962184 , ...,  -1.234335  ,
         -3.6463182 ,  -1.1297145 ],
       [ -8.52523   , -12.177396  ,  -7.682827  , ...,  -3.7515726 ,
         -6.7035413 ,  -1.1476717 ],
       ...,
       [ -5.5241995 ,  -6.3738503 ,  -4.397642  , ...,  -1.0400367 ,
         -4.020339  ,  -1.2807798 ],
       [ -8.758766  ,  -8.992627  ,  -6.172684  , ...,  -5.1390853 ,
         -1.7771287 ,  -0.52790165],
       [ -4.85185   ,  -8.200217  ,  -6.006732  , ...,  -9.261877  ,
         -3.3985126 ,  -8.189007  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.500:  58%|█████▊    | 68/118 [00:02<00:01, 30.20it/s]

Tensor(array([[ -2.6202636 ,  -5.644156  ,  -4.248241  , ...,  -6.9664235 ,
         -2.4046848 ,  -6.920378  ],
       [ -8.965254  ,  -8.497764  ,  -0.04573059, ..., -11.14221   ,
         -8.069796  ,  -9.136386  ],
       [ -7.803292  , -11.777763  ,  -8.421495  , ...,  -0.02800274,
         -9.3847685 ,  -3.6613808 ],
       ...,
       [ -5.141767  ,  -8.395867  ,  -4.631598  , ...,  -7.30736   ,
         -3.5969362 ,  -7.32857   ],
       [ -6.856947  ,  -3.562699  ,  -3.9334633 , ...,  -1.4494628 ,
         -4.127499  ,  -1.0157821 ],
       [ -0.4335966 ,  -8.911261  ,  -7.460803  , ...,  -6.6684513 ,
         -4.9423256 ,  -6.9254465 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.482:  63%|██████▎   | 74/118 [00:02<00:01, 36.44it/s]

Tensor(array([[-4.1410317 , -3.8402553 , -0.26369882, ..., -4.7491856 ,
        -2.2891514 , -4.793445  ],
       [-4.154117  , -8.400584  , -7.43809   , ..., -6.6827445 ,
        -2.3446977 , -4.068849  ],
       [-6.8560753 , -7.179066  , -4.2920156 , ..., -7.482045  ,
        -0.11776257, -3.628481  ],
       ...,
       [-3.335659  , -6.288715  , -1.4440401 , ..., -7.953074  ,
        -5.117411  , -7.285031  ],
       [-7.8083754 , -0.22285652, -2.8838701 , ..., -2.7234223 ,
        -3.8079524 , -4.1211667 ],
       [-0.13591146, -7.9943504 , -2.8910246 , ..., -5.8785877 ,
        -5.3470764 , -7.22573   ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  

Train set: Avg loss: 0.550:  71%|███████   | 84/118 [00:02<00:00, 40.20it/s]

Tensor(array([[ -4.9229627 ,  -4.78775   ,  -2.0944068 , ...,  -6.5077486 ,
         -0.34518433,  -5.591067  ],
       [ -0.5340719 ,  -7.0118647 ,  -2.4458697 , ...,  -5.1409035 ,
         -4.072274  ,  -6.440177  ],
       [ -8.473354  , -13.557873  ,  -8.824217  , ...,  -4.332173  ,
         -6.1923347 ,  -0.11294031],
       ...,
       [ -9.066347  , -13.533726  ,  -9.338709  , ...,  -6.956356  ,
         -5.370469  ,  -2.2094693 ],
       [ -1.5231919 ,  -9.542147  ,  -3.7528172 , ...,  -6.9052086 ,
         -6.885577  ,  -6.3982034 ],
       [ -0.4304471 ,  -7.64555   ,  -1.2499118 , ...,  -5.230608  ,
         -4.9447327 ,  -6.2684584 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.467:  85%|████████▍ | 100/118 [00:03<00:00, 44.09it/s]

Tensor(array([[ -9.089392  ,  -7.7571282 ,  -0.03136444, ..., -11.028791  ,
         -6.2897434 ,  -8.441237  ],
       [ -5.6850014 ,  -9.149522  ,  -7.088815  , ...,  -0.2678237 ,
         -6.261397  ,  -1.601125  ],
       [ -6.537759  ,  -5.587435  ,  -3.0897856 , ...,  -7.1621842 ,
         -5.8958864 ,  -6.9056845 ],
       ...,
       [ -8.415758  , -13.268588  ,  -7.7109175 , ...,  -6.7826104 ,
         -5.4241056 ,  -1.9178987 ],
       [ -7.780764  ,  -8.914963  ,  -5.499035  , ...,  -3.518642  ,
         -5.5795326 ,  -1.4921145 ],
       [ -6.611562  ,  -1.1514578 ,  -3.4709895 , ...,  -5.943383  ,
         -0.8030257 ,  -5.194641  ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.470:  89%|████████▉ | 105/118 [00:03<00:00, 34.29it/s]

Tensor(array([[ -9.716406  ,  -0.02920961,  -4.5162554 , ...,  -6.7347274 ,
         -5.177143  ,  -7.8665624 ],
       [ -7.9180155 , -11.118706  ,  -7.553564  , ...,  -4.5349903 ,
         -5.607207  ,  -1.1123333 ],
       [ -5.6604805 ,  -5.5224547 ,  -1.3456633 , ...,  -6.9938755 ,
         -5.226178  ,  -6.4829054 ],
       ...,
       [ -3.2562125 ,  -7.418148  ,  -6.434478  , ...,  -4.410464  ,
         -3.2893934 ,  -3.2115388 ],
       [ -8.660235  ,  -9.120533  ,  -3.7959714 , ...,  -5.3368673 ,
         -5.125876  ,  -2.7031593 ],
       [ -7.2067266 ,  -6.161022  ,  -3.981381  , ...,  -6.703995  ,
         -2.45445   ,  -6.2313595 ]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52, 

Train set: Avg loss: 0.490: 100%|██████████| 118/118 [00:03<00:00, 32.96it/s]


Tensor(array([[-6.5432816e+00, -3.7369554e+00, -5.5403104e+00, ...,
        -2.6674225e+00, -2.0150948e+00, -4.7951245e-01],
       [-1.3948441e-02, -1.4953245e+01, -9.1166058e+00, ...,
        -6.2809916e+00, -7.9020472e+00, -7.6647444e+00],
       [-6.4198050e+00, -9.0718670e+00, -8.3805628e+00, ...,
        -1.6994905e-01, -6.4838448e+00, -2.0744891e+00],
       ...,
       [-7.0827808e+00, -6.4825444e+00, -4.8151555e+00, ...,
        -1.1219978e-01, -4.1519985e+00, -2.6908896e+00],
       [-3.7279129e-03, -1.7952869e+01, -8.3960037e+00, ...,
        -9.6356049e+00, -8.9135857e+00, -8.9864063e+00],
       [-3.7821047e+00, -1.2628450e+01, -6.4597569e+00, ...,
        -4.0909181e+00, -3.6952314e+00, -2.6010084e-01]], dtype=float32), requires_grad=True) (Tensor(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39, 

In [39]:
line(
    [train_loss_list, test_accuracy_list],
    x_max=num_epochs,
    yaxis2_range=[0, 1],
    use_secondary_yaxis=True,
    labels={"x": "Batches seen", "y1": "Cross entropy loss", "y2": "Test accuracy"},
    title="MLP training on MNIST from scratch!",
    width=800,
)

Note - this training loop (if done correctly) will look to the one we used in earlier sections is that we're using SGD rather than Adam. You can try adapting your Adam code from the previous day's exercises, and get the same results as you have in earlier sections.

If it works then congratulations - you've implemented a fully-functional autograd system!

# Bonus

Congratulations on finishing the day's main content! Here are a few more bonus things for you to explore.

### In-Place Operation Warnings

The most severe issue with our current system is that it can silently compute the wrong gradients when in-place operations are used. Have a look at how [PyTorch handles it](https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd) and implement a similar system yourself so that it either computes the right gradients, or raises a warning.

### In-Place `ReLU`

Instead of implementing ReLU in terms of maximum, implement your own forward and backward functions that support `inplace=True`.

### Backward for `einsum`

Write the backward pass for your equivalent of `torch.einsum`.

### Reuse of Module during forward

Consider the following MLP, where the same `nn.ReLU` instance is used twice in the forward pass. Without running the code, explain whether this works correctly or not with reference to the specifics of your implementation.

```python
class MyModule(Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(28*28, 64)
        self.linear2 = Linear(64, 64)
        self.linear3 = Linear(64, 10)
        self.relu = ReLU()
    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)
```

<details>
<summary>Answer (what you should find)</summary>

This implementation will work correctly.

The danger of reusing modules is that you'd be creating a cyclical computational graph (because the same parameters would appear twice), but the `ReLU` module doesn't have any parameters (or any internal state), so this isn't a problem. It's effectively just a wrapper for the `relu` function, and you could replace `self.relu` with applying the `relu` function directly without changing the model's behaviour.

This is slightly different if we're thinking about adding **hooks** to our model. Hooks are functions that are called during the forward or backward pass, and they can be used to inspect the state of the model during training. We generally want each hook to be associated with a single position in the model, rather than being called at two different points.
</details>

### Convolutional layers

Now that you've implemented a linear layer, it should be relatively straightforward to take your convolutions code from day 2 and use it to make a convolutional layer. How much better performance do you get on the MNIST task once you replace your first two linear layers with convolutions?

### ResNet Support

Make a list of the features that would need to be implemented to support ResNet inference, and training. It will probably take too long to do all of them, but pick some interesting features to start implementing.

### Central Difference Checking

Write a function that compares the gradients from your backprop to a central difference method. See [Wikipedia](https://en.wikipedia.org/wiki/Finite_difference) for more details.

### Non-Differentiable Function Support

Your `Tensor` does not currently support equivalents of `torch.all`, `torch.any`, `torch.floor`, `torch.less`, etc. which are non-differentiable functions of Tensors. Implement them so that they are usable in computational graphs, but gradients shouldn't flow through them (their contribution is zero).

### Differentiation wrt Keyword Arguments

In the real PyTorch, you can sometimes pass tensors as keyword arguments and differentiation will work, as in `t.add(other=t.tensor([3,4]), input=t.tensor([1,2]))`. In other similar looking cases like `t.dot`, it raises an error that the argument must be passed positionally. Decide on a desired behavior in your system and implement and test it.

### `torch.stack`

So far we've registered a separate backwards for each input argument that could be a Tensor. This is problematic if the function can take any number of tensors like `torch.stack` or `numpy.stack`. Think of and implement the backward function for stack. It may require modification to your other code.