In [1]:
import numpy as np 
from matplotlib.pyplot import *
from tqdm import trange

In [5]:
# Credit: Mr. George Hotz https://github.com/geohot/ai-notebooks/blob/master/mnist_from_scratch.ipynb
def fetch(url):
  import requests, gzip, os, hashlib, numpy
  fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
  if os.path.isfile(fp):
    with open(fp, "rb") as f:
      dat = f.read()
  else:
    with open(fp, "wb") as f:
      dat = requests.get(url).content
      f.write(dat)
  return numpy.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]

In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.device('cpu')
class DehuaNet(nn.Module):
    
    def __init__(self):
        super(DehuaNet, self).__init__() 
        
        # Setting up layers 
        self.l1 = nn.Linear(784, 128) 
        self.l2 = nn.Linear(128, 10) 
        self.sm = nn.LogSoftmax(dim=0) 
    
    def forward(self, x): 
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.sm(x) 
        return x 
    

In [110]:
net = DehuaNet()
pog = torch.flatten(torch.tensor(X_train[0], dtype=torch.float))
out = net(pog)
loss_function = nn.MSELoss() 
label = torch.zeros(10)
label[Y_train[0]] = 1
loss = loss_function(out, label)

net.zero_grad()

print('l1.bias.grad before backward')
print(net.l1.bias.grad)

loss.backward()

print('l1.bias.grad after backward')
print(net.l1.bias.grad)

l1.bias.grad before backward
None
l1.bias.grad after backward
tensor([ 0.8109,  0.0000, -1.7319,  0.3373,  0.0000,  0.6716,  0.0000,  2.1138,
        -0.1788,  0.0000, -0.4555,  0.5834,  0.0000,  1.7677,  0.4669,  2.0506,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.3320,  0.0000,
         0.0000,  2.1688,  0.0000,  0.9091,  0.0000,  0.0000,  0.0000,  0.0000,
         0.2632,  0.0000, -1.0058,  0.0000, -1.6659,  0.0000,  0.1417,  0.0000,
         1.0756,  0.0000, -0.8825,  0.2219,  0.0000,  0.0000,  0.0000,  1.9283,
        -1.5145,  1.3810,  1.1389,  0.0000,  0.0000,  0.0000, -1.4677, -0.6325,
         1.0156,  0.0000, -0.6919,  0.0000,  0.8559, -0.6209,  0.0000,  0.0000,
        -1.0316,  0.0000, -2.5735,  0.4230,  0.0000,  0.0000,  0.0000, -1.8071,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.5090,
         0.0000,  0.0000,  0.0000, -1.1047,  0.8474,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.96