In [1]:
import pickle, gzip, math, os, time, shutil, torch, matplotlib as mpl, numpy as np
from pathlib import Path
from torch import tensor
from fastcore.test import test_close

# Seed torch's RNG so the random weight initialisation further down is reproducible.
torch.manual_seed(42)

# Display defaults: grayscale colormap for images, compact 2-decimal printing.
mpl.rcParams["image.cmap"] = "gray"
torch.set_printoptions(precision=2, linewidth=125, sci_mode=False)
np.set_printoptions(precision=2, linewidth=125)

# Load the gzipped MNIST pickle from data/mnist.pkl.gz. The third split stored
# in the file is discarded (presumably the test set -- confirm).
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
path_data = Path("data")
path_gz = path_data / "mnist.pkl.gz"
with gzip.open(path_gz, "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
# Convert the four loaded objects to torch tensors.
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])

In [2]:
x_train.shape

torch.Size([50000, 784])

In [3]:
n, m = x_train.shape

In [4]:
y_train.max() + 1

tensor(10)

In [5]:
nh = 50

In [6]:
# Layer 1 maps m input features to nh hidden units; layer 2 maps nh units to a
# single output. Weights are standard-normal draws, biases start at zero.
# The two randn calls stay in this order so the seeded RNG yields the same values.
w1, b1 = torch.randn(m, nh), torch.zeros(nh)
w2, b2 = torch.randn(nh, 1), torch.zeros(1)

In [7]:
def lin(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Affine layer: matrix-multiply `x` by `w`, then add the bias `b`.

    Args:
        x: input tensor (left operand of the matrix product).
        w: weight tensor (right operand of the matrix product).
        b: bias, broadcast-added to the product.

    Returns:
        The tensor `x @ w + b`.
    """
    projected = torch.matmul(x, w)
    return projected + b

In [None]:
t = lin(x_valid, w1, b1)
t.shape

torch.Size([10000, 50])

In [9]:
def relu(t: torch.Tensor) -> torch.Tensor:
    """Rectified linear unit: elementwise clamp of `t` to a minimum of 0.

    Returns a new tensor; `t` itself is not modified.
    """
    return torch.clamp(t, min=0.0)

In [10]:
t = relu(t)

In [11]:
def model(xb: torch.Tensor):
    """Forward pass of the two-layer network: linear -> ReLU -> linear.

    Uses the module-level parameters `w1`, `b1`, `w2`, `b2`.

    Args:
        xb: batch of inputs fed to the first linear layer.

    Returns:
        The output of the second linear layer.
    """
    hidden = relu(lin(xb, w1, b1))
    return lin(hidden, w2, b2)

In [12]:
res = model(x_valid)

In [None]:
res

tensor([[  25.75],
        [ -13.06],
        [-114.79],
        ...,
        [ -67.44],
        [ -74.48],
        [ -60.19]])

In [19]:
res.squeeze(1).shape

torch.Size([10000])

In [15]:
y_valid.shape

torch.Size([10000])

In [21]:
preds = model(x_train)

In [22]:
def mse(output: torch.Tensor, target: torch.Tensor):
    """Mean squared error between column 0 of `output` and `target`.

    Args:
        output: 2-D predictions; only `output[:, 0]` is used.
        target: values subtracted from that column.

    Returns:
        A scalar tensor holding the mean of the squared differences.
    """
    residual = output[:, 0] - target
    return (residual ** 2).mean()

In [None]:
mse(preds, y_train)

tensor(4308.76)

In [None]:
test_tens = torch.randn(m, nh)
test_tens.g = "my_g"

In [None]:
print([attr for attr in dir(test_tens) if attr == "g"])

['g']


In [None]:
class ExtendedTensor(torch.Tensor):
    """Tensor subclass that carries an extra `g` attribute (e.g. a gradient).

    Args:
        data: anything `torch.as_tensor` accepts (tensor, list, array, ...).
        g: initial value stored on the instance as `.g`; defaults to None.
        *args, **kwargs: accepted for call compatibility and ignored.
    """

    @staticmethod
    def __new__(cls, data, g=None, *args, **kwargs):
        instance = torch.Tensor._make_subclass(
            cls,
            torch.as_tensor(data),
        )
        # Keep the extra attribute on the instance so `.g` is always defined.
        instance.g = g
        # __new__ must return the object it built; without this the
        # constructor call evaluates to None.
        return instance

In [38]:
test_tens.t().shape

torch.Size([50, 784])

In [41]:
# Same forward pass as model(), unrolled step by step so the intermediate
# activations (l1, l2) remain available for the manual backward pass below.
l1 = lin(x_train, w1, b1)
l2 = relu(l1)
output = lin(l2, w2, b2)

In [None]:
output[:, 0].shape

torch.Size([50000])

In [54]:
output.squeeze()

tensor([-30.97, -99.38,   8.72,  ..., -52.12, -46.25,  -4.35])

In [None]:
diff = output[:, 0] - y_train

In [51]:
diff

tensor([-35.97, -99.38,   4.72,  ..., -60.12, -50.25, -12.35])

In [55]:
loss = diff.pow(2).mean()

In [56]:
loss

tensor(4308.76)

In [None]:
2.0 * diff[:, None]

tensor([[ -71.94],
        [-198.76],
        [   9.45],
        ...,
        [-120.23],
        [-100.50],
        [ -24.69]])

In [None]:
2.0 * diff[:, None] / x_train.shape[0]

tensor([[-0.00],
        [-0.00],
        [ 0.00],
        ...,
        [-0.00],
        [-0.00],
        [-0.00]])

In [None]:
# Gradient of the MSE loss w.r.t. `output`: d/d(output) mean(diff**2) = 2*diff/n.
# `[:, None]` adds a trailing axis so the gradient has the same (n, 1) shape as `output`.
output.g = 2.0 * diff[:, None] / x_train.shape[0]

In [86]:
output.g, output.g.shape

(tensor([[-0.00],
         [-0.00],
         [ 0.00],
         ...,
         [-0.00],
         [-0.00],
         [-0.00]]),
 torch.Size([50000, 1]))

In [90]:
inp = l2
w = w2
b = b2

In [92]:
w.shape, w.t().shape

(torch.Size([50, 1]), torch.Size([1, 50]))

In [None]:
# Chain rule through output = inp @ w + b: the gradient w.r.t. the layer input
# is the output gradient multiplied by the transposed weights.
inp.g = output.g @ w.t()

In [None]:
inp.g.shape

torch.Size([50000, 50])

In [112]:
output.g.shape, output.g.unsqueeze(1).shape

(torch.Size([50000, 1]), torch.Size([50000, 1, 1]))

In [108]:
inp.shape

torch.Size([50000, 50])

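The dimensions broadcast here: `inp.unsqueeze(-1)` has shape (50000, 50, 1) and `output.g.unsqueeze(1)` has shape (50000, 1, 1), so their elementwise product has shape (50000, 50, 1). Summing over the first (batch) axis then gives a (50, 1) gradient, the same shape as `w`.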


In [None]:
# Weight gradient: per-sample outer product of the layer input and the output
# gradient (built via broadcasting), summed over the batch axis.
# Mathematically equal to inp.t() @ output.g.
w.g = (inp.unsqueeze(-1) * output.g.unsqueeze(1)).sum(0)

In [130]:
w.g.shape

torch.Size([50, 1])

In [119]:
output.g.sum().shape, output.g.sum(0).shape

(torch.Size([]), torch.Size([1]))

In [133]:
(l2 > 0).float() * inp.g

tensor([[     0.00,     -0.00,     -0.00,  ...,      0.00,     -0.00,      0.00],
        [     0.00,     -0.00,     -0.00,  ...,      0.00,     -0.00,      0.01],
        [    -0.00,      0.00,      0.00,  ...,     -0.00,      0.00,     -0.00],
        ...,
        [     0.00,     -0.00,     -0.00,  ...,      0.00,     -0.00,      0.00],
        [     0.00,     -0.00,     -0.00,  ...,      0.00,     -0.00,      0.00],
        [     0.00,     -0.00,     -0.00,  ...,      0.00,     -0.00,      0.00]])

In [None]:
def lin_grad(inp: torch.Tensor, out: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> None:
    """Backward pass of a linear layer `out = inp @ w + b`.

    Reads the upstream gradient from `out.g` and stores the resulting
    gradients as `.g` attributes on `inp`, `w` and `b`.

    Args:
        inp: layer input, shape (n, in_features).
        out: layer output, shape (n, out_features); must already carry `out.g`.
        w: weight matrix, shape (in_features, out_features).
        b: bias vector, shape (out_features,).
    """
    inp.g = out.g @ w.t()
    # Sum over the batch of per-sample outer products inp[i] (x) out.g[i].
    # Written as a matmul so no (n, in_features, out_features) broadcast
    # intermediate is materialised.
    w.g = inp.t() @ out.g
    b.g = out.g.sum(0)

SyntaxError: incomplete input (3268164355.py, line 1)