# Setup 
Imports, installations, data acquisition

<BR>

In [None]:
!pip install -Uqq fastbook

In [1]:
import fastbook
fastbook.setup_book()

In [2]:
from fastai.vision.all import *
from fastbook import *

matplotlib.rc('image', cmap='Greys')

In [3]:
path = untar_data(URLs.MNIST)

In [4]:
path

Path('/Users/yvesborenstein/.fastai/data/mnist_png')

In [5]:
Path.BASE_PATH = path

In [6]:
path.ls()

(#3) [Path('.DS_Store'),Path('training'),Path('testing')]

# Setup training and validation sets
Also, sanity checks at each step, indented for legibility

<BR>

In [7]:
def make_L(digit, folder_name):
    return (path/folder_name/str(digit)).ls()

In [92]:
        threes = make_L(3, 'training')
        threes

(#6131) [Path('training/3/49081.png'),Path('training/3/51816.png'),Path('training/3/39566.png'),Path('training/3/24251.png'),Path('training/3/20989.png'),Path('training/3/29013.png'),Path('training/3/58832.png'),Path('training/3/9294.png'),Path('training/3/20023.png'),Path('training/3/36899.png')...]

In [93]:
        pre_training_digits = [make_L(i, 'training') for i in range(10)]
        pre_training_digits[3]

(#6131) [Path('training/3/49081.png'),Path('training/3/51816.png'),Path('training/3/39566.png'),Path('training/3/24251.png'),Path('training/3/20989.png'),Path('training/3/29013.png'),Path('training/3/58832.png'),Path('training/3/9294.png'),Path('training/3/20023.png'),Path('training/3/36899.png')...]

In [10]:
training_digits = [ torch.stack([tensor(Image.open(o)) for o in make_L(digit, 'training')]).float()/255 for digit in range(10) ]
validation_digits = [ torch.stack([tensor(Image.open(o)) for o in make_L(digit, 'testing')]).float()/255 for digit in range(10) ]

In [91]:
        training_digits[0].shape, validation_digits[0].shape

(torch.Size([5923, 28, 28]), torch.Size([980, 28, 28]))

In [11]:
training_x   = torch.cat((training_digits)).view(-1, 28*28)
training_y   = torch.cat([tensor([digit] * training_digits[digit].shape[0]) for digit in range(10)]).unsqueeze(1)

validation_x = torch.cat((validation_digits)).view(-1, 28*28)
validation_y = torch.cat([tensor([digit] * validation_digits[digit].shape[0]) for digit in range(10)]).unsqueeze(1)

In [90]:
        training_x.shape, training_y.shape, validation_x.shape, validation_y.shape

(torch.Size([60000, 784]),
 torch.Size([60000, 1]),
 torch.Size([10000, 784]),
 torch.Size([10000, 1]))

In [13]:
training_dset   = list(zip(training_x,training_y))
validation_dset = list(zip(validation_x,validation_y))

In [89]:
        x,y = training_dset[0]
        x.shape,y

(torch.Size([784]), tensor([0]))

In [22]:
training_dl   = DataLoader(training_dset, batch_size=256)
validation_dl = DataLoader(validation_dset, batch_size=256)

In [23]:
dls = DataLoaders(training_dl, validation_dl)

# Define Loss function and batch accuracy

In [94]:
def mnist_loss(predictions, targets):
    
#     make predictions output into probabilities,i.e numbers between 0 & 1, all summing up to 1
    predictions = nn.Softmax(dim=1)(predictions)
    
#     replace the target-eth digit by 1 minus itself, for each image
    targets = targets.long().unsqueeze(1)
    scores = predictions.scatter(1,targets, 1-torch.gather(predictions,1,targets))
    
#     Sum of top 2 explanation & justification:
#     Note that only 1 or 2 of the values in predictions/scores are of interest: 1 - prediction value
#     for the correct digit, and the largest prediction value in the case of a wrong prediction. These should be the 2 largest 
#     of the 10 values. So we take the 2 largest and sum them. Note that if we had taken the sum of all 10 values, 
#     we would actually be losing information: due to the softmax/probability property of all values summing to 1, 
#     throwing in a 1 - pred instead of pred for the correct coordinate, actually only speaks about that coordinate's value.
#     Here, I also want to take into account the second highest value. Furthormore, this accentuates the ratio of 
#     average loss(incorrect predictions/ average loss(correct prediction), compared with taking the sum of all ten.
#     This way, the prediction is doubly penalized when incorrect: once for incorrectly identifying the wrong digit, 
#     and once for failing to identify the right digit. And it is only lightly penalized when predicting correctly, but still penalizes 
#     having a second high probability. The aim is to drive the model to predict correctly and be sure of its predictions.
    return torch.topk(scores,2,1)[0].sum(1).mean()

In [87]:
        trgts  = tensor([2,6,1])
        prds   = tensor([1.2, 0.6, 8.1, 5.1, 0.6, 3.4, 4.2, 5.34, 3.33, 2.3], 
                        [3.5, 4.3, 5.5, 2.4, 3.45, 6.7, 10.6, 2.86, 3.67, 2.98], 
                        [1.1, 6.5, 1.3, 1.8, 3.4, 2.6, 4.5, 2.65, 3.22, 2.2])
        nn.Softmax(dim=1)(prds)

tensor([[8.7176e-04, 4.7843e-04, 8.6502e-01, 4.3067e-02, 4.7843e-04, 7.8676e-03, 1.7510e-02, 5.4749e-02, 7.3357e-03, 2.6189e-03],
        [7.9955e-04, 1.7794e-03, 5.9079e-03, 2.6615e-04, 7.6055e-04, 1.9615e-02, 9.6903e-01, 4.2160e-04, 9.4771e-04, 4.7535e-04],
        [3.4952e-03, 7.7386e-01, 4.2690e-03, 7.0384e-03, 3.4862e-02, 1.5664e-02, 1.0473e-01, 1.6467e-02, 2.9119e-02, 1.0500e-02]])

In [88]:
        mnist_loss(prds, trgts)

tensor(0.1904)

In [25]:
def batch_accuracy(xb, yb):
    
    preds = nn.Softmax(dim=1)(xb)
    correct = (torch.max(preds, dim=1)[1] == yb)
    return correct.float().mean()

# Define model
Copied on MNIST_Basics notebook, except output is 10-fold instead of 1-fold

<BR>

In [26]:
simple_net = nn.Sequential(
    nn.Linear(28*28,30),
    nn.ReLU(),
    nn.Linear(30,10)
)

In [27]:
learn = Learner(dls, simple_net, opt_func=SGD,
                loss_func=mnist_loss, metrics=batch_accuracy)

In [95]:
learn.lr_find()

# nasty error

RuntimeError: Index tensor must have the same number of dimensions as input tensor

In [None]:
lr_min = ???
lr_steep = ???
learn.fit_one_cycle(20, slice(lr_min,lr_steep))

# Ideas for Improvements

### Data Augmentations:
* translation by 1,2, or 3 pixels 
* symetries about axes [that's what I wrote before Alex's comment on Symetries]
* rotations

### Different architectures
* more layers
* use convolutional layers
* different activations 
