In [2]:
import numpy as np
import torch 
import torch.nn as nn 

In [3]:
# Tiny Conv1d (1 -> 1 channel, kernel size 3, no bias) with a hand-picked kernel,
# used to inspect the gradient that flows back to the input.
conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, bias=False)
conv.weight.data = torch.tensor([1., 2., 3.], dtype=torch.float32).reshape(conv.weight.shape)

# all-ones signal of length 10, shaped (batch, channels, length)
x = torch.ones((1, 1, 10), dtype=torch.float32, requires_grad=True)
y = conv(x)
# d(sum(y))/dx is left in x.grad
y.sum().backward()
# x.grad

In [4]:
# 1D
# Cross-correlation written as "gather every length-k window, then one matmul",
# checked against NumPy's own implementation.

N = 10
inp = np.arange(N)
pad = 1
x = np.pad(inp, (pad, pad))  # zero-pad both ends

n = x.shape[0]
k = 3
nout = n - k + 1  # number of length-k windows at stride 1

# (nout, k) view of all windows. sliding_window_view is the bounds-checked,
# read-only counterpart of a hand-written as_strided(x, (nout, k), (s, s)).
u = np.lib.stride_tricks.sliding_window_view(x, k)
assert u.shape == (nout, k)
w = np.array([1, 1, 1])

# each output element is the dot product of one window with the kernel
y = u @ w

print(y)
# np.convolve flips its second argument, so pass the reversed kernel to get the
# same (un-flipped) sliding dot product for any w, not only symmetric ones.
print(np.convolve(x, w[::-1], mode='valid'))


[ 1  3  6  9 12 15 18 21 24 17]
[ 1  3  6  9 12 15 18 21 24 17]


In [3]:
from numpy.lib.stride_tricks import as_strided, sliding_window_view # <3
from collections import Counter

# Prototype: batched, stride-1 2D cross-correlation computed as a single matmul
# over all sliding windows, validated against torch.nn.Conv2d.
c = 3           # in channels
b = 64          # batches
h, w = 28, 28   # height, width (for 1D set h=channels, w=length)
ph, pw = 1, 1   # padding       (for 1D set ph=0, pw=padding)
cout = 3        # out channels
kh, kw = 3, 3   # filter size   (for 1D set kh=channels)

# output spatial size at stride 1; used by the shape asserts and the final reshape
wout = w + 2*pw - kw + 1
hout = h + 2*ph - kh + 1
print("out: ", hout, wout)

# random input (swap in np.arange(b*c*h*w).reshape(b, c, h, w) for a readable debug input)
inp = np.random.randn(b, c, h, w)
xpad = np.pad(inp, ((0, 0), (0, 0), (ph, ph), (pw, pw)))  # zero-pad the two spatial axes only

u = sliding_window_view(xpad, (b, c, kh, kw)) # get all subarrays (a view; only uses stride tricks)
assert u.shape == (1, 1, hout, wout) + (b, c, kh, kw) # can index into each subarray nicely, but we just want to flatten everything
u = u.reshape(-1, b, c*kh*kw) # flatten each subarray, except along batch dim
u = u.swapaxes(0, 1)      # put batch dim as first dim, for matmul later. 2nd dim contains each subarray
assert u.shape == (b, hout*wout, c*kh*kw)

# random filter bank (swap in np.ones((cout, c, kh, kw)) for a readable debug filter)
filtr = np.random.randn(cout, c, kh, kw)

# one column per output channel, flattened in the same (c, kh, kw) order as the windows
wmat = filtr.reshape((cout, c*kh*kw)).T
assert wmat.shape == (c*kh*kw, cout)

ymat = u @ wmat  # do the actual computation (!!)
print(ymat.shape)

out = ymat.swapaxes(2, 1).reshape(b, cout, hout, wout) # unravel image and put channel dim before hout,wout dims
print(out.shape)

# Reference implementation. kernel_size is (height, width), i.e. (kh, kw).
conv = nn.Conv2d(c, cout, (kh, kw), stride=1, padding=(ph, pw), bias=False)
conv.weight.data = torch.tensor(filtr, dtype=torch.float32, requires_grad=True)
x = torch.tensor(inp, dtype=torch.float32, requires_grad=True)
y = conv(x)

yt = y.data.numpy()
ym = out.astype(np.float32)
# sorted distinct values after rounding to 2 decimals (a coarse, order-independent comparison)
yts = np.array(sorted(set(yt.flatten().round(2))))
yms = np.array(sorted(set(ym.flatten().round(2))))

print("same: ", np.allclose(yt, ym, atol=1e-4))
# guard on shape first: allclose raises on arrays that cannot be broadcast together
print("same set: ",  yts.shape == yms.shape and np.allclose(yts, yms, atol=1e-4))

out:  28 28
(64, 784, 3)


(64, 3, 28, 28)

same:  True
same set:  True


In [9]:
from numpy.lib.stride_tricks import sliding_window_view

def pad2D(x, ph, pw):
    """Symmetrically zero-pad or crop the two trailing axes of a 4-D array.

    `x` must have shape (b, c, h, w). A positive `ph` / `pw` adds that many
    zeros on both sides of the height / width axis; a negative value removes
    that many entries from both sides instead. The result therefore has shape
    (b, c, h+2*ph, w+2*pw).
    """
    assert x.ndim == 4
    # negative amounts become a symmetric crop, non-negative ones keep the full axis
    rows = slice(-ph, ph) if ph < 0 else slice(None)
    cols = slice(-pw, pw) if pw < 0 else slice(None)
    cropped = x[:, :, rows, cols]
    # only non-negative amounts are handed to np.pad
    grow_h = max(0, ph)
    grow_w = max(0, pw)
    return np.pad(cropped, ((0, 0), (0, 0), (grow_h, grow_h), (grow_w, grow_w)))

# like 10x slower than torch.nn.Conv2D :(
def convolve2D(inp: np.ndarray, filtr: np.ndarray, padding: tuple[int,int] = (0, 0)) -> np.ndarray:
    """Batched, stride-1 2D cross-correlation (no kernel flip) done as one matmul.

    Args:
        inp: input of shape (b, c, h, w).
        filtr: filter bank of shape (cout, c, kh, kw); its channel count must match `inp`.
        padding: (ph, pw), forwarded to `pad2D` for the height and width axes.

    Returns:
        Array of shape (b, cout, h + 2*ph - kh + 1, w + 2*pw - kw + 1).

    Raises:
        AssertionError: if either argument is not 4-D or the channel counts differ.
    """
    assert inp.ndim == filtr.ndim == 4, (inp.ndim, filtr.ndim)
    batch, cin, height, width = inp.shape
    cout, cin_filtr, kh, kw = filtr.shape
    assert cin_filtr == cin, (inp.shape, filtr.shape)

    ph, pw = padding
    padded = pad2D(inp, ph, pw)

    # output spatial size at stride 1; needed for the reshapes below
    hout = height + 2*ph - kh + 1
    wout = width + 2*pw - kw + 1
    patch_len = cin * kh * kw

    # every (batch, cin, kh, kw) sub-block of the padded input, as a strided view
    windows = sliding_window_view(padded, (batch, cin, kh, kw))
    assert windows.shape == (1, 1, hout, wout, batch, cin, kh, kw)
    # flatten each window per batch element (this reshape has to copy), then move batch to the front
    patches = windows.reshape(hout*wout, batch, patch_len).transpose(1, 0, 2)
    assert patches.shape == (batch, hout*wout, patch_len)

    # one column per output channel, flattened in the same (c, kh, kw) order as the patches
    weights = filtr.reshape(cout, patch_len).T
    assert weights.shape == (patch_len, cout)

    flat_out = patches @ weights  # the actual computation
    assert flat_out.shape == (batch, hout*wout, cout)

    # channels before the spatial axes, then unravel the flattened positions
    return flat_out.transpose(0, 2, 1).reshape(batch, cout, hout, wout)

# Problem size
b, c, cout = 10, 3, 3   # batches, in channels, out channels
h, w = 28, 28           # height, width (for 1D set h=channels, w=length)
ph, pw = 1, 1           # padding       (for 1D set ph=0, pw=padding)
kh, kw = 3, 3           # filter size   (for 1D set kh=channels)

# Forward pass on random data
inp = np.random.randn(b, c, h, w)
filtr = np.random.randn(cout, c, kh, kw)
out = convolve2D(inp, filtr, (ph, pw))

print(inp.shape, out.shape)

# Stand-in upstream gradient: all ones, i.e. the gradient of out.sum()
gy = np.ones_like(out)

# Padding for the input-gradient convolution, chosen so gx has the input's spatial size
pyh, pyw = kh - ph - 1, kw - pw - 1
print(pyh, pyw)

# Input gradient: convolve gy with the spatially reversed filters,
# with the in/out channel axes exchanged
filtr_rev = filtr[..., ::-1, ::-1].swapaxes(0, 1)
print(gy.shape, filtr_rev.shape)
gx = convolve2D(gy, filtr_rev, (pyh, pyw))
print(gx.shape)

assert gx.shape == inp.shape

# Weight gradient: treat channels as the batch axis and use gy as the filter bank
x = pad2D(inp, ph, pw).swapaxes(0, 1)
y = gy.swapaxes(0, 1)
print(x.shape, y.shape)

gw = convolve2D(x, y, (0, 0)).swapaxes(0, 1)
assert gw.shape == filtr.shape

(10, 3, 28, 28) (10, 3, 28, 28)
1 1
(10, 3, 28, 28) (3, 3, 3, 3)
(10, 3, 28, 28)
(3, 10, 30, 30) (3, 10, 28, 28)


In [10]:
# pytorch equiv
# Re-run the same convolution in PyTorch and compare the forward output and both
# gradients against the NumPy results (out, gx, gw) computed in the previous cell.
# kernel_size is (height, width), i.e. (kh, kw).
conv = nn.Conv2d(c, cout, (kh, kw), stride=1, padding=(ph, pw), bias=False)
conv.weight.data = torch.tensor(filtr, dtype=torch.float32, requires_grad=True)
x = torch.tensor(inp, dtype=torch.float32, requires_grad=True)
y = conv(x)
yt = y.data.numpy()
ym = out

# report the max absolute deviation, consistently with the backward checks below
print("same fwd:", np.allclose(yt, ym, atol=1e-4), np.abs(yt - ym).max())

# sum() gives an all-ones upstream gradient, matching the gy used for gx / gw
y.sum().backward()
gxt = x.grad.data.numpy()
gxm = gx
gwt = conv.weight.grad.data.numpy()
gwm = gw

print("same bwd gx:", np.allclose(gxt, gxm, atol=1e-4), np.abs(gxt - gxm).max())
# looser tolerance than the other two checks
print("same bwd gw:", np.allclose(gwt, gwm, atol=1e-2), np.abs(gwt - gwm).max())

same fwd: True 3.8491405724983e-06
same bwd gx: True 7.105834693987845e-07
same bwd gw: True 7.957397089342066e-05


In [34]:
from micrograd.nn import Conv2D 
from micrograd.core import Tensor
from micrograd.conv import convolve2D, pad2D

# Check the micrograd Conv2D layer (forward output, input gradient, weight gradient)
# against torch.nn.Conv2d on the same random input and the same weights.
c = 3           # in channels
b = 10          # batches
h, w = 28, 28   # height, width (for 1D set h=channels, w=length)
ph, pw = 1,1   # padding       (for 1D set ph=0, pw=padding)
cout = 3        # out channels
kh, kw = 3, 3   # filter size   (for 1D set kh=channels)

# micrograd side
# NOTE(review): Conv2D's arguments look like (in_channels, out_channels, kernel_size, padding)
# -- confirm against micrograd.nn
x = Tensor(np.random.randn(b, c, h, w))
conv = Conv2D(c, cout, (kh, kw), (ph, pw))
y = conv(x)
y.sum().backward()

# torch side: same input values, and the torch layer's weights replaced by the micrograd layer's
xt = torch.tensor(x.data, requires_grad=True)
convt = nn.Conv2d(c, cout, (kh, kw), 1, (ph, pw), bias=False)  # positional: stride=1, padding=(ph, pw)
convt.weight.data = torch.tensor(conv.w.data)
yt = convt(xt)
yt.sum().backward()

# forward output, gradient w.r.t. the input, gradient w.r.t. the weights
assert np.allclose(y.data, yt.data.numpy(), atol=1e-5)
assert np.allclose(x.grad.data, xt.grad.data.numpy(), atol=1e-4)
assert np.allclose(conv.w.grad.data, convt.weight.grad.data.numpy(), atol=1e-4)

In [289]:
# max pooling: sizes for the toy example

b, c = 1, 2     # batches, in channels
h, w = 5, 5     # input height, width
ph, pw = 0, 0   # padding
kh, kw = 3, 3   # pooling window (filter) size

def pool2d(x, kersize: tuple[int,int], ):
    """Stride-1, unpadded 2D max pooling, plus the gradient of ``y.sum()`` w.r.t. ``x``.

    Args:
        x: array of shape (b, c, h, w).
        kersize: (kh, kw), height and width of the pooling window.

    Returns:
        (y, gx). `y` has shape (b, c, h-kh+1, w-kw+1) and holds each window's maximum.
        `gx` is a float array of shape (b, c, h, w): for an all-ones upstream gradient,
        each entry counts how many windows have their (first) maximum at that position.
    """
    b, c, h, w = x.shape
    kh, kw = kersize
    # x is used as given (nothing pads it), so the output size depends only on
    # the input and window sizes -- not on any module-level padding setting.
    hout = h - kh + 1
    wout = w - kw + 1

    # (b, c, hout, wout, kh, kw): one kh x kw window per output position (a view)
    u = sliding_window_view(x, (kh, kw), axis=(2, 3))
    y = u.max((-1, -2))

    # first maximum of each window, as a flat index into the (kh, kw) window
    widx = np.argmax(u.reshape(b, c, hout, wout, kh*kw), -1)
    ys, xs = np.unravel_index(widx, (kh, kw))
    # shift window-local coordinates by the window's top-left corner
    ys = ys + np.arange(hout).reshape(-1, 1)
    xs = xs + np.arange(wout).reshape(1, -1)

    # flat index into each (h, w) image
    linidx = (w*ys + xs).reshape(b, c, -1)

    gx = np.zeros((b, c, w*h))
    gy = np.ones_like(y).reshape(b, c, -1)  # simulated upstream gradient

    # Scatter-add gy into gx. np.add.at accumulates when the same input position
    # is the maximum of several overlapping windows (plain `gx[...] += gy` would not).
    bi = np.arange(b).reshape(-1, 1, 1)
    ci = np.arange(c).reshape(1, -1, 1)
    np.add.at(gx, (bi, ci, linidx), gy)

    gx = gx.reshape(b, c, h, w)

    return y, gx


# Input: the integers 0 .. b*c*h*w-1 in random order, so all entries are distinct
x = np.arange(b*c*h*w)
np.random.shuffle(x)
x = x.reshape(b, c, h, w)
display(x)

# pooled output and the gradient of its sum w.r.t. x
y, gx = pool2d(x, (kh, kw))

array([[[[40,  9, 18,  5, 48],
         [10, 28, 13, 22,  7],
         [19, 25, 47, 20, 26],
         [ 4, 29, 38, 30, 16],
         [15,  1,  6, 34, 33]],

        [[39, 12, 11, 35, 31],
         [24, 43, 46, 21, 27],
         [45, 23,  2, 37,  8],
         [ 0, 17, 32, 49, 42],
         [44, 36, 41, 14,  3]]]])

In [287]:
# Reference: the same stride-1 max pooling in PyTorch, on the same input
pool = nn.MaxPool2d(kernel_size=(kh, kw), stride=1)
xt = torch.tensor(x, dtype=torch.float32, requires_grad=True)
yt = pool(xt)
yt.sum().backward(retain_graph=True)

# pooled values and input gradient must agree with the NumPy pool2d results
assert np.allclose(yt.data.numpy(), y)
assert np.allclose(xt.grad.data.numpy(), gx)

In [299]:
# np.unique with axis=None flattens the input; return_index gives the first
# position of each distinct value in the flattened array
a = np.array([[1, 1, 2, 3, 1, 3, 2]])
np.unique(a, return_index=True, axis=None)

(array([1, 2, 3]), array([0, 2, 3]))