In [0]:
! pip install fastai2 utils python-mnist

In [0]:
from fastai2.vision.all import *
from utils import *
from mnist import MNIST
import numpy as np

In [0]:
# Download data
! wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
! wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
! wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
! wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
! mkdir samples;mv t*-ubyte.gz samples;cd samples;gunzip t*-ubyte.gz

In [0]:
# Load data
mndata = MNIST('samples')
images, labels = mndata.load_training()
print(len(images[0]),labels[0])

784 5


In [0]:
# One-hot label
nb_classes = 10
targets = np.array(labels).reshape(-1)
train_y =  torch.tensor(np.eye(nb_classes)[targets]).float()
train_y.shape

torch.Size([60000, 10])

In [0]:
# Normalize the data to keep gradients manageable
train_x = tensor(images).float()/255
train_x.shape

torch.Size([60000, 784])

In [0]:
# Move data to mini batch size
dset = list(zip(train_x[:50000],train_y[:50000]))
dl = DataLoader(dset, batch_size=200)

valid_dset = list(zip(train_x[50000:],train_y[50000:]))
valid_dl = DataLoader(valid_dset, batch_size=200)

In [0]:
def init_params(size, std=1.0):
    return (torch.randn(size)*std).requires_grad_()

# Although I written my own softmax function however that has some NaN value problems, I'm using nn.Softmax replace it.
def softmax(layer_outputs):
    softmax_outs = []
    E = tensor([2.71828182846]) # calculate the exponential value
    
    for i in layer_outputs:
        # normalize values
        exp_values = E ** i
        normalize_base = torch.sum(exp_values)
        print(exp_values / normalize_base)
        softmax_outs.append(exp_values / normalize_base)
    return softmax_outs

def simple_net(xb): 
    res = xb@w1 + b1
    res = res.max(tensor(0.0))
    res = res@w2 + b2
    res = res.max(tensor(0.0))
    m = nn.Softmax(dim=1)
    return m(res)

def mnist_loss(softmax_outputs,target_outputs):
    loss = tensor([0.])
    # cross-entropy
    for i in range(len(softmax_outputs)):
        eps=1e-7
        loss += -(torch.log(softmax_outputs[i] @ target_outputs[i]+eps))

    return loss/len(softmax_outputs)

def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    #print(loss)
    loss.backward()

def train_epoch(model, lr, params):
    for xb,yb in dl:
        calc_grad(xb, yb, model)
        for p in params:
            p.data -= p.grad*lr
            p.grad.zero_()

def batch_accuracy(xb, yb):
    preds = simple_net(xb)
    xy_set = list(zip(preds,yb))
    pred_rights = 0

    for x,y in xy_set:
        #print(torch.argmax(x),torch.argmax(y))
        if torch.argmax(x) == torch.argmax(y):
            pred_rights += 1
        
    return tensor(pred_rights / len(xy_set))

def validate_epoch():
    accs = [batch_accuracy(xb, yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)

In [0]:
w1 = init_params((28*28,30))
b1 = init_params(30)
w2 = init_params((30,10))
b2 = init_params(10)

In [0]:
lr = 1.
params = w1,b1,w2,b2
train_epoch(simple_net, lr, params)
validate_epoch()

0.1125

In [0]:
for i in range(20):
    train_epoch(simple_net, lr, params)
    print(validate_epoch())

0.1412
0.1726
0.1728
0.1779
0.1782
0.1919
0.2566
0.2538
0.2524
0.2571
0.2581
0.2581
0.3308
0.3439
0.3627
0.3688
0.435
0.4237
0.4773
0.4789


In [0]:
# for xb,yb in dl:
#     res = xb@w1 + b1
#     res = res.max(tensor(0.0))
#     res = res@w2 + b2
#     res = res.max(tensor(0.0))
#     s = softmax(res)
#     #print(s)
#     # kkk = mnist_loss(s,yb)
#     # print(kkk)
#     break

## References

https://www.fast.ai/

Data Source: http://yann.lecun.com/exdb/mnist/
Y. LeCun and C. Cortes. Mnist handwritten digit database. AT&T Labs [Online]. 

[Application of Crack Identification Techniques for an Aging Concrete Bridge Inspection Using an Unmanned Aerial Vehicle](https://www.researchgate.net/publication/325663483_Application_of_Crack_Identification_Techniques_for_an_Aging_Concrete_Bridge_Inspection_Using_an_Unmanned_Aerial_Vehicle)

[Extract images from idx3](https://stackoverflow.com/questions/40427435/extract-images-from-idx3-ubyte-file-or-gzip-via-python)


https://discuss.pytorch.org/t/how-to-solve-the-loss-become-nan-because-of-using-torch-log/54499