# MNIST in Swift for TensorFlow

Blog post: https://rickwierenga.com/blog/s4tf/s4tf-mnist.html

## Importing dependencies

In [0]:
import TensorFlow
import Foundation
import FoundationNetworking

## Loading MNIST

Recreating much of the MNIST API in https://github.com/tensorflow/swift-models, because it is not currently working in Colab. The team told us they are working on an update. (https://github.com/tensorflow/swift-models/issues/233)

In [0]:
public struct LabeledExample: TensorGroup {
    public var label: Tensor<Int32>
    public var data: Tensor<Float>

    public init(label: Tensor<Int32>, data: Tensor<Float>) {
        self.label = label
        self.data = data
    }

    public init<C: RandomAccessCollection>(
        _handles: C
    ) where C.Element: _AnyTensorHandle {
        precondition(_handles.count == 2)
        let labelIndex = _handles.startIndex
        let dataIndex = _handles.index(labelIndex, offsetBy: 1)
        label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
        data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
    }
}

In [0]:
public struct DatasetUtilities {
    public static let currentWorkingDirectoryURL = URL(
        fileURLWithPath: FileManager.default.currentDirectoryPath)

    public static func fetchResource(
        filename: String,
        remoteRoot: URL,
        localStorageDirectory: URL = currentWorkingDirectoryURL
    ) -> Data {
        print("Loading resource: \(filename)")

        let resource = ResourceDefinition(
            filename: filename,
            remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        let localURL = resource.localURL

        if !FileManager.default.fileExists(atPath: localURL.path) {
            print(
                "File does not exist locally at expected path: \(localURL.path) and must be fetched"
            )
            fetchFromRemoteAndSave(resource)
        }

        do {
            print("Loading local data at: \(localURL.path)")
            let data = try Data(contentsOf: localURL)
            print("Succesfully loaded resource: \(filename)")
            return data
        } catch {
            fatalError("Failed to contents of resource: \(localURL)")
        }
    }

    struct ResourceDefinition {
        let filename: String
        let remoteRoot: URL
        let localStorageDirectory: URL

        var localURL: URL {
            localStorageDirectory.appendingPathComponent(filename)
        }

        var remoteURL: URL {
            remoteRoot.appendingPathComponent(filename).appendingPathExtension("gz")
        }

        var archiveURL: URL {
            localURL.appendingPathExtension("gz")
        }
    }

    static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) {
        let remoteLocation = resource.remoteURL
        let archiveLocation = resource.archiveURL

        do {
            print("Fetching URL: \(remoteLocation)...")
            let archiveData = try Data(contentsOf: remoteLocation)
            print("Writing fetched archive to: \(archiveLocation.path)")
            try archiveData.write(to: archiveLocation)
        } catch {
            fatalError("Failed to fetch and save resource with error: \(error)")
        }
        print("Archive saved to: \(archiveLocation.path)")

        extractArchive(for: resource)
    }

    static func extractArchive(for resource: ResourceDefinition) {
        print("Extracting archive...")

        let archivePath = resource.archiveURL.path

        #if os(macOS)
            let gunzipLocation = "/usr/bin/gunzip"
        #else
            let gunzipLocation = "/bin/gunzip"
        #endif

        let task = Process()
        task.executableURL = URL(fileURLWithPath: gunzipLocation)
        task.arguments = [archivePath]
        do {
            try task.run()
            task.waitUntilExit()
        } catch {
            fatalError("Failed to extract \(archivePath) with error: \(error)")
        }
    }
}

In [0]:
public struct MNIST {
    public let trainingDataset: Dataset<LabeledExample>
    public let testDataset: Dataset<LabeledExample>
    public let trainingExampleCount = 60000

    public init() {
        self.init(flattening: false, normalizing: false)
    }

    public init(
        flattening: Bool = false, normalizing: Bool = false,
        localStorageDirectory: URL = DatasetUtilities.currentWorkingDirectoryURL
    ) {
        self.trainingDataset = Dataset<LabeledExample>(
            elements: fetchDataset(
                localStorageDirectory: localStorageDirectory,
                imagesFilename: "train-images-idx3-ubyte",
                labelsFilename: "train-labels-idx1-ubyte",
                flattening: flattening,
                normalizing: normalizing))

        self.testDataset = Dataset<LabeledExample>(
            elements: fetchDataset(
                localStorageDirectory: localStorageDirectory,
                imagesFilename: "t10k-images-idx3-ubyte",
                labelsFilename: "t10k-labels-idx1-ubyte",
                flattening: flattening,
                normalizing: normalizing))
    }
}

fileprivate func fetchDataset(
    localStorageDirectory: URL,
    imagesFilename: String,
    labelsFilename: String,
    flattening: Bool,
    normalizing: Bool
) -> LabeledExample {
    guard let remoteRoot:URL = URL(string: "http://yann.lecun.com/exdb/mnist") else {
        fatalError("Failed to create MNST root url: http://yann.lecun.com/exdb/mnist")
    }

    let imagesData = DatasetUtilities.fetchResource(
        filename: imagesFilename,
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)
    let labelsData = DatasetUtilities.fetchResource(
        filename: labelsFilename,
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)

    let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
    let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)

    let rowCount = labels.count
    let (imageWidth, imageHeight) = (28, 28)

    if flattening {
        var flattenedImages = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
            / 255.0
        if normalizing {
            flattenedImages = flattenedImages * 2.0 - 1.0
        }
        return LabeledExample(label: Tensor(labels), data: flattenedImages)
    } else {
        return LabeledExample(
            label: Tensor(labels),
            data:
                Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
                .transposed(withPermutations: [0, 2, 3, 1]) / 255  // NHWC
        )
    }
}

In [5]:
let mnist = MNIST(flattening: true, normalizing: true)

Loading resource: train-images-idx3-ubyte
Loading local data at: /content/train-images-idx3-ubyte
Succesfully loaded resource: train-images-idx3-ubyte
Loading resource: train-labels-idx1-ubyte
Loading local data at: /content/train-labels-idx1-ubyte
Succesfully loaded resource: train-labels-idx1-ubyte
Loading resource: t10k-images-idx3-ubyte
Loading local data at: /content/t10k-images-idx3-ubyte
Succesfully loaded resource: t10k-images-idx3-ubyte
Loading resource: t10k-labels-idx1-ubyte
Loading local data at: /content/t10k-labels-idx1-ubyte
Succesfully loaded resource: t10k-labels-idx1-ubyte


In [0]:
let imageHeight = 28, imageWidth = 28
let imageSize = imageHeight * imageWidth

## Constructing the network

In [0]:
struct Model: Layer {
    var hiddenLayer = Dense<Float>(inputSize: imageSize, outputSize: 300, activation: relu)
    var outputLayer = Dense<Float>(inputSize: 300, outputSize: 10, activation: softmax)
    
    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return input.sequenced(through: hiddenLayer, outputLayer)
    }
}

In [0]:
var model = Model()

## Training

In [0]:
let batchSize = 512
let testBatches = mnist.testDataset.batched(batchSize)

In [0]:
let optimizer = Adam(for: model)

In [11]:
for epoch in 1...10 {
    // Update parameters
    Context.local.learningPhase = .training
    let trainingShuffled = mnist.trainingDataset.shuffled(sampleCount: mnist.trainingExampleCount, randomSeed: Int64(epoch))
    for batch in trainingShuffled.batched(batchSize) {
        let (labels, images) = (batch.label, batch.data)
        let (_, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
            let logits = model(images)
            return softmaxCrossEntropy(logits: logits, labels: labels)
        }
        optimizer.update(&model, along: gradients)
    }

    // Evaluate model
    Context.local.learningPhase = .inference

    var correctTrainGuessCount = 0
    var totalTrainGuessCount = 0
    for batch in mnist.trainingDataset.batched(512) {
        let (labels, images) = (batch.label, batch.data)
        let logits = model(images)
        let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
        correctTrainGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
        totalTrainGuessCount += batchSize
    }
    let trainAcc = Float(correctTrainGuessCount) / Float(totalTrainGuessCount)

    var correctValGuessCount = 0
    var totalValGuessCount = 0
    for batch in testBatches {
        let (labels, images) = (batch.label, batch.data)
        let logits = model(images)
        let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
        correctValGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
        totalValGuessCount += batchSize
    }
    let valAcc = Float(correctValGuessCount) / Float(totalValGuessCount)
    
    print("\(epoch) | Training accuracy: \(trainAcc) | Validation accuracy: \(valAcc)")
}

1 | Training accuracy: 0.89893407 | Validation accuracy: 0.88798827
2 | Training accuracy: 0.9233481 | Validation accuracy: 0.9081055
3 | Training accuracy: 0.93415654 | Validation accuracy: 0.9181641
4 | Training accuracy: 0.94264764 | Validation accuracy: 0.9238281
5 | Training accuracy: 0.9500132 | Validation accuracy: 0.9303711
6 | Training accuracy: 0.9538036 | Validation accuracy: 0.93173826
7 | Training accuracy: 0.9592161 | Validation accuracy: 0.9397461
8 | Training accuracy: 0.9612023 | Validation accuracy: 0.93837893
9 | Training accuracy: 0.96456236 | Validation accuracy: 0.94365233
10 | Training accuracy: 0.9648603 | Validation accuracy: 0.94208986
