# MNIST in Swift for TensorFlow (ConvNet)

Blog post: https://rickwierenga.com/blog/s4tf/s4tf-mnist.html

## Importing dependencies

In [0]:
%install-location $cwd/swift-install
%install '.package(url: "https://github.com/tensorflow/swift-models", .branch("tensorflow-0.6"))' Datasets

In [0]:
import TensorFlow
import Foundation
import Datasets

In [0]:
import Python
// Python-interop handles: matplotlib (imported as `matplotlib.pylab`) is used
// for the plots at the end of the notebook, NumPy for the per-epoch accuracy
// history buffers.
let plt = Python.import("matplotlib.pylab")
let np = Python.import("numpy")

In [0]:
// Pull in the notebook display helper and ask it to use matplotlib's
// "inline" backend.
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")

## Loading MNIST

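The `MNIST` dataset loader used below comes from the `Datasets` module of [tensorflow/swift-models](https://github.com/tensorflow/swift-models), installed at the top of this notebook.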

In [0]:
// Number of samples per minibatch; shared by the dataset loader and by the
// training / evaluation loops below.
let batchSize = 512

In [0]:
// Load MNIST with images left unflattened (`flattening: false`) and
// normalized (`normalizing: true`).
// NOTE(review): the model's first convolution expects a single input channel
// and its first dense layer is sized for 28x28 inputs, so images are
// presumably [N, 28, 28, 1] — confirm against the Datasets module.
let mnist = MNIST(batchSize: batchSize, flattening: false, normalizing: true)

## Constructing the network

In [0]:
/// A small convolutional classifier with ten output classes.
///
/// Architecture: two ReLU convolutions with `.same` padding, a 2x2 max-pool
/// with stride 1, dropout, flatten, a 128-unit ReLU dense layer, dropout, and
/// a final 10-unit dense layer.
///
/// The network returns raw, unnormalized logits. Loss functions such as
/// `softmaxCrossEntropy(logits:labels:)` apply the softmax themselves, so the
/// final layer must not apply it as well. Call `softmax(model(x))` when class
/// probabilities are needed; `argmax` predictions are identical either way.
struct Model: Layer {
    var conv1 = Conv2D<Float>(
        filterShape: (2, 2, 1, 32),
        padding: .same,
        activation: relu
    )
    var conv2 = Conv2D<Float>(
        filterShape: (3, 3, 32, 64),
        padding: .same,
        activation: relu
    )
    var maxPooling = MaxPool2D<Float>(poolSize: (2, 2), strides: (1, 1))
    var dropout1 = Dropout<Float>(probability: 0.25)

    var flatten2 = Flatten<Float>()
    // 27 * 27 * 64 matches a 28x28 input: the `.same` convolutions keep the
    // spatial size and the 2x2 stride-1 pool trims one row and one column
    // (assuming the pool's default padding is `.valid` — TODO confirm).
    var dense1 = Dense<Float>(inputSize: 27 * 27 * 64, outputSize: 128, activation: relu)
    var dropout2 = Dropout<Float>(probability: 0.5)
    // No activation here: this layer emits logits (see the type-level note).
    var dense2 = Dense<Float>(inputSize: 128, outputSize: 10)

    /// Runs the forward pass.
    ///
    /// - Parameter input: A batch of single-channel images.
    /// - Returns: One row of 10 unnormalized class scores (logits) per image.
    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        // Split across two `sequenced` calls to keep each call short.
        let features = input.sequenced(through: conv1, conv2, maxPooling, dropout1)
        return features.sequenced(through: flatten2, dense1, dropout2, dense2)
    }
}

In [0]:
// Declared `var` because the optimizer updates the model in place
// (`optimizer.update(&model, along:)` in the training loop).
var model = Model()

## Training

In [0]:
// Number of passes of the outer training loop.
let epochs = 10
// Per-epoch accuracy buffers (NumPy arrays of zeros, one slot per epoch).
// The training loop writes one value per epoch into each; they are plotted at
// the end of the notebook.
var trainHistory = np.zeros(epochs)
var valHistory = np.zeros(epochs)

In [0]:
// Adam optimizer created for `model`; no hyperparameters are passed, so the
// library defaults apply.
let optimizer = Adam(for: model)

In [0]:
/// Number of minibatches needed to cover `sampleCount` samples, including a
/// trailing partial batch (ceiling division by `batchSize`).
func batchCount(for sampleCount: Int) -> Int {
    return (sampleCount + batchSize - 1) / batchSize
}

/// Half-open range of sample indices that make up batch `index`.
///
/// The upper bound is clamped to `sampleCount`, so the final batch may hold
/// fewer than `batchSize` samples but never reaches past the data.
func sampleRange(ofBatch index: Int, in sampleCount: Int) -> Range<Int> {
    let start = index * batchSize
    return start..<min(start + batchSize, sampleCount)
}

/// Fraction of the first `sampleCount` samples whose arg-max prediction
/// matches the label, evaluated batch by batch.
///
/// - Parameters:
///   - model: The classifier to evaluate.
///   - images: Input images; samples are indexed along the first axis.
///   - labels: Integer class labels aligned with `images`.
///   - sampleCount: Number of samples to evaluate.
/// - Returns: Accuracy in `0...1`, or `0` when there are no samples.
func classificationAccuracy(
    of model: Model,
    images: Tensor<Float>,
    labels: Tensor<Int32>,
    sampleCount: Int
) -> Float {
    var correctCount = 0
    var seenCount = 0
    for batchIndex in 0..<batchCount(for: sampleCount) {
        let range = sampleRange(ofBatch: batchIndex, in: sampleCount)
        let logits = model(images[range])
        let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels[range]
        correctCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
        // Count the samples actually evaluated, not the nominal batch size,
        // so a short final batch does not inflate the denominator.
        seenCount += range.count
    }
    guard seenCount > 0 else { return 0 }
    return Float(correctCount) / Float(seenCount)
}

for epoch in 0..<epochs {
    // Update parameters
    Context.local.learningPhase = .training
    for batchIndex in 0..<batchCount(for: mnist.trainingSize) {
        let range = sampleRange(ofBatch: batchIndex, in: mnist.trainingSize)
        let images = mnist.trainingImages[range]
        let labels = mnist.trainingLabels[range]
        let (_, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
            let logits = model(images)
            return softmaxCrossEntropy(logits: logits, labels: labels)
        }
        optimizer.update(&model, along: gradients)
    }

    // Evaluate model
    Context.local.learningPhase = .inference

    let trainAcc = classificationAccuracy(
        of: model,
        images: mnist.trainingImages,
        labels: mnist.trainingLabels,
        sampleCount: mnist.trainingSize
    )
    trainHistory[epoch] = PythonObject(trainAcc)

    let valAcc = classificationAccuracy(
        of: model,
        images: mnist.testImages,
        labels: mnist.testLabels,
        sampleCount: mnist.testSize
    )
    valHistory[epoch] = PythonObject(valAcc)

    print("\(epoch) | Training accuracy: \(trainAcc) | Validation accuracy: \(valAcc)")
}

## Inspecting training history

In [0]:
// Plot the per-epoch training accuracy recorded by the training loop
// (x-axis: epoch index).
plt.plot(trainHistory)
plt.title("Training History")
plt.show()

In [0]:
// Plot the per-epoch accuracy on the MNIST test split recorded by the
// training loop (x-axis: epoch index).
plt.plot(valHistory)
plt.title("Validation History")
plt.show()