<a href="https://colab.research.google.com/github/8bitmp3/testing-s4tf-things/blob/master/test_fashion_mnist_gan_s4tf_feb_12_2020.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Testing FashionMNIST on a GAN

New dataset: 
https://github.com/brettkoonce/swift-models/blob/fashion-mnist/Datasets/MNIST/FashionMNIST.swift (code author: @brettkoonce)

GAN and various packages: https://github.com/tensorflow/swift-models/ 

In [0]:
import Foundation
import TensorFlow

In [0]:
// https://github.com/tensorflow/swift-models/blob/master/Support/Stderr.swift

var stderr = FileHandle.standardError

extension FileHandle: TextOutputStream {
    public func write(_ string: String) {
        guard let data = string.data(using: .utf8) else { return }
        self.write(data)
    }
}

public func printError(_ message: String) {
    print(message, to: &stderr)
}

In [0]:
// https://github.com/tensorflow/swift-models/blob/master/Support/FileManagement.swift

#if canImport(FoundationNetworking)
    import FoundationNetworking
#endif

public func createDirectoryIfMissing(at path: String) throws {
    guard !FileManager.default.fileExists(atPath: path) else { return }
    try FileManager.default.createDirectory(
        atPath: path,
        withIntermediateDirectories: true,
        attributes: nil)
}

public func download(from source: URL, to destinationDirectory: URL) throws {
    try createDirectoryIfMissing(at: destinationDirectory.path)

    let fileName = source.lastPathComponent
    let destinationFile = destinationDirectory.appendingPathComponent(fileName).path

    let downloadedFile = try Data(contentsOf: source)
    try downloadedFile.write(to: URL(fileURLWithPath: destinationFile))
}

In [0]:
// https://github.com/tensorflow/swift-models/blob/master/Datasets/DatasetUtilities.swift

#if canImport(FoundationNetworking)
    import FoundationNetworking
#endif

public enum DatasetUtilities {
    public static let currentWorkingDirectoryURL = URL(
        fileURLWithPath: FileManager.default.currentDirectoryPath)

    public static func downloadResource(
        filename: String,
        fileExtension: String,
        remoteRoot: URL,
        localStorageDirectory: URL = currentWorkingDirectoryURL
    ) -> URL {
        printError("Loading resource: \(filename)")

        let resource = ResourceDefinition(
            filename: filename,
            fileExtension: fileExtension,
            remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        let localURL = resource.localURL

        if !FileManager.default.fileExists(atPath: localURL.path) {
            printError(
                "File does not exist locally at expected path: \(localURL.path) and must be fetched"
            )
            fetchFromRemoteAndSave(resource)
        }

        return localURL
    }

    public static func fetchResource(
        filename: String,
        fileExtension: String,
        remoteRoot: URL,
        localStorageDirectory: URL = currentWorkingDirectoryURL
    ) -> Data {
        let localURL = DatasetUtilities.downloadResource(
            filename: filename, fileExtension: fileExtension, remoteRoot: remoteRoot,
            localStorageDirectory: localStorageDirectory)

        do {
            let data = try Data(contentsOf: localURL)
            return data
        } catch {
            fatalError("Failed to contents of resource: \(localURL)")
        }
    }

    struct ResourceDefinition {
        let filename: String
        let fileExtension: String
        let remoteRoot: URL
        let localStorageDirectory: URL

        var localURL: URL {
            localStorageDirectory.appendingPathComponent(filename)
        }

        var remoteURL: URL {
            remoteRoot.appendingPathComponent(filename).appendingPathExtension(fileExtension)
        }

        var archiveURL: URL {
            localURL.appendingPathExtension(fileExtension)
        }
    }

    static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) {
        let remoteLocation = resource.remoteURL
        let archiveLocation = resource.localStorageDirectory

        do {
            printError("Fetching URL: \(remoteLocation)...")
            try download(from: remoteLocation, to: archiveLocation)
        } catch {
            fatalError("Failed to fetch and save resource with error: \(error)")
        }
        printError("Archive saved to: \(archiveLocation.path)")

        extractArchive(for: resource)
    }

    static func extractArchive(for resource: ResourceDefinition) {
        printError("Extracting archive...")

        let archivePath = resource.archiveURL.path

        #if os(macOS)
            let binaryLocation = "/usr/bin/"
        #else
            let binaryLocation = "/bin/"
        #endif

        let toolName: String
        let arguments: [String]
        switch resource.fileExtension {
        case "gz":
            toolName = "gunzip"
            arguments = [archivePath]
        case "tar.gz", "tgz":
            toolName = "tar"
            arguments = ["xzf", archivePath, "-C", resource.localStorageDirectory.path]
        default:
            printError("Unable to find archiver for extension \(resource.fileExtension).")
            exit(-1)
        }
        let toolLocation = "\(binaryLocation)\(toolName)"

        let task = Process()
        task.executableURL = URL(fileURLWithPath: toolLocation)
        task.arguments = arguments
        do {
            try task.run()
            task.waitUntilExit()
        } catch {
            printError("Failed to extract \(archivePath) with error: \(error)")
            exit(-1)
        }

        if FileManager.default.fileExists(atPath: archivePath) {
            do {
                try FileManager.default.removeItem(atPath: archivePath)
            } catch {
                printError("Could not remove archive, error: \(error)")
                exit(-1)
            }
        }
    }
}

In [0]:
// https://github.com/tensorflow/swift-models/blob/master/Datasets/LabeledExample.swift

public struct LabeledExample: TensorGroup {
    public var label: Tensor<Int32>
    public var data: Tensor<Float>

    public init(label: Tensor<Int32>, data: Tensor<Float>) {
        self.label = label
        self.data = data
    }

    public init<C: RandomAccessCollection>(
        _handles: C
    ) where C.Element: _AnyTensorHandle {
        precondition(_handles.count == 2)
        let labelIndex = _handles.startIndex
        let dataIndex = _handles.index(labelIndex, offsetBy: 1)
        label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
        data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
    }
}

In [0]:
// https://github.com/tensorflow/swift-models/blob/master/Datasets/ImageClassificationDataset.swift

public protocol ImageClassificationDataset {
    init()
    var trainingDataset: Dataset<LabeledExample> { get }
    var testDataset: Dataset<LabeledExample> { get }
    var trainingExampleCount: Int { get }
    var testExampleCount: Int { get }
}

In [0]:
// The FashionMNIST dataset
// https://github.com/brettkoonce/swift-models/blob/fashion-mnist/Datasets/MNIST/FashionMNIST.swift

public struct FashionMNIST: ImageClassificationDataset {
    public let trainingDataset: Dataset<LabeledExample>
    public let testDataset: Dataset<LabeledExample>
    public let trainingExampleCount = 60000
    public let testExampleCount = 10000

    public init() {
        self.init(flattening: false, normalizing: false)
    }

    public init(
        flattening: Bool = false, normalizing: Bool = false,
        localStorageDirectory: URL = FileManager.default.temporaryDirectory.appendingPathComponent(
            "FashionMNIST")
    ) {
        self.trainingDataset = Dataset<LabeledExample>(
            elements: fetchDataset(
                localStorageDirectory: localStorageDirectory,
                imagesFilename: "train-images-idx3-ubyte",
                labelsFilename: "train-labels-idx1-ubyte",
                flattening: flattening,
                normalizing: normalizing))

        self.testDataset = Dataset<LabeledExample>(
            elements: fetchDataset(
                localStorageDirectory: localStorageDirectory,
                imagesFilename: "t10k-images-idx3-ubyte",
                labelsFilename: "t10k-labels-idx1-ubyte",
                flattening: flattening,
                normalizing: normalizing))
    }
}

fileprivate func fetchDataset(
    localStorageDirectory: URL,
    imagesFilename: String,
    labelsFilename: String,
    flattening: Bool,
    normalizing: Bool
) -> LabeledExample {
    guard let remoteRoot = URL(string: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/") else {
        fatalError("Failed to create FashionMNIST root url: http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/")
    }

    let imagesData = DatasetUtilities.fetchResource(
        filename: imagesFilename,
        fileExtension: "gz",
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)
    let labelsData = DatasetUtilities.fetchResource(
        filename: labelsFilename,
        fileExtension: "gz",
        remoteRoot: remoteRoot,
        localStorageDirectory: localStorageDirectory)

    let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
    let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)

    let rowCount = labels.count
    let (imageWidth, imageHeight) = (28, 28)

    if flattening {
        var flattenedImages =
            Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
            / 255.0
        if normalizing {
            flattenedImages = flattenedImages * 2.0 - 1.0
        }
        return LabeledExample(label: Tensor(labels), data: flattenedImages)
    } else {
        return LabeledExample(
            label: Tensor(labels),
            data:
                Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
                .transposed(permutation: [0, 2, 3, 1]) / 255  // NHWC
        )
    }
}

In [0]:
// https://github.com/tensorflow/swift-models/blob/master/GAN/main.swift

let epochCount = 10
let batchSize = 32
let outputFolder = "./output/"
let imageHeight = 28
let imageWidth = 28
let imageSize = imageHeight * imageWidth
let latentSize = 64

// Models
struct Generator: Layer {
    var dense1 = Dense<Float>(
        inputSize: latentSize, outputSize: latentSize * 2,
        activation: { leakyRelu($0) })

    var dense2 = Dense<Float>(
        inputSize: latentSize * 2, outputSize: latentSize * 4,
        activation: { leakyRelu($0) })

    var dense3 = Dense<Float>(
        inputSize: latentSize * 4, outputSize: latentSize * 8,
        activation: { leakyRelu($0) })

    var dense4 = Dense<Float>(
        inputSize: latentSize * 8, outputSize: imageSize,
        activation: tanh)

    var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2)
    var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4)
    var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let x1 = batchnorm1(dense1(input))
        let x2 = batchnorm2(dense2(x1))
        let x3 = batchnorm3(dense3(x2))
        return dense4(x3)
    }
}

struct Discriminator: Layer {
    var dense1 = Dense<Float>(
        inputSize: imageSize, outputSize: 256,
        activation: { leakyRelu($0) })

    var dense2 = Dense<Float>(
        inputSize: 256, outputSize: 64,
        activation: { leakyRelu($0) })

    var dense3 = Dense<Float>(
        inputSize: 64, outputSize: 16,
        activation: { leakyRelu($0) })

    var dense4 = Dense<Float>(
        inputSize: 16, outputSize: 1,
        activation: identity)

    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        input.sequenced(through: dense1, dense2, dense3, dense4)
    }
}

// Loss functions
@differentiable
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
    sigmoidCrossEntropy(
        logits: fakeLogits,
        labels: Tensor(ones: fakeLogits.shape))
}

@differentiable
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
    let realLoss = sigmoidCrossEntropy(
        logits: realLogits,
        labels: Tensor(ones: realLogits.shape))
    let fakeLoss = sigmoidCrossEntropy(
        logits: fakeLogits,
        labels: Tensor(zeros: fakeLogits.shape))
    return realLoss + fakeLoss
}

/// Returns `size` samples of noise vector.
func sampleVector(size: Int) -> Tensor<Float> {
    Tensor(randomNormal: [size, latentSize])
}

In [9]:
// Modified from https://github.com/tensorflow/swift-models/blob/master/GAN/main.swift

let dataset = FashionMNIST(flattening: true, normalizing: true)

Loading resource: train-images-idx3-ubyte
File does not exist locally at expected path: /tmp/FashionMNIST/train-images-idx3-ubyte and must be fetched
Fetching URL: http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz...
Archive saved to: /tmp/FashionMNIST
Extracting archive...
Loading resource: train-labels-idx1-ubyte
File does not exist locally at expected path: /tmp/FashionMNIST/train-labels-idx1-ubyte and must be fetched
Fetching URL: http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz...
Archive saved to: /tmp/FashionMNIST
Extracting archive...
Loading resource: t10k-images-idx3-ubyte
File does not exist locally at expected path: /tmp/FashionMNIST/t10k-images-idx3-ubyte and must be fetched
Fetching URL: http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz...
Archive saved to: /tmp/FashionMNIST
Extracting archive...
Loading resource: t10k-labels-idx1-ubyte
File does not exist locally

In [0]:
// https://github.com/tensorflow/swift-models/blob/master/Support/Image.swift

public struct Image {
    public enum ByteOrdering {
        case bgr
        case rgb
    }

    enum ImageTensor {
        case float(data: Tensor<Float>)
        case uint8(data: Tensor<UInt8>)
    }

    let imageData: ImageTensor

    public var tensor: Tensor<Float> {
        switch self.imageData {
        case let .float(data): return data
        case let .uint8(data): return Tensor<Float>(data)
        }
    }

    public init(tensor: Tensor<UInt8>) {
        self.imageData = .uint8(data: tensor)
    }

    public init(tensor: Tensor<Float>) {
        self.imageData = .float(data: tensor)
    }

    public init(jpeg url: URL, byteOrdering: ByteOrdering = .rgb) {
        let loadedFile = _Raw.readFile(filename: StringTensor(url.absoluteString))
        let loadedJpeg = _Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "")
        if byteOrdering == .bgr {
            self.imageData = .uint8(
                data: _Raw.reverse(loadedJpeg, dims: Tensor<Bool>([false, false, false, true])))
        } else {
            self.imageData = .uint8(data: loadedJpeg)
        }
    }

    public func save(to url: URL, format: _Raw.Format = .grayscale, quality: Int64 = 95) {
        let outputImageData: Tensor<UInt8>
        switch format {
        case .grayscale:
            switch self.imageData {
            case let .uint8(data): outputImageData = data
            case let .float(data):
                let lowerBound = data.min(alongAxes: [0, 1])
                let upperBound = data.max(alongAxes: [0, 1])
                let adjustedData = (data - lowerBound) * (255.0 / (upperBound - lowerBound))
                outputImageData = Tensor<UInt8>(adjustedData)
            }
        case .rgb:
            switch self.imageData {
            case let .uint8(data): outputImageData = data
            case let .float(data):
                outputImageData = Tensor<UInt8>(
                    _Raw.clipByValue(t: data, clipValueMin: Tensor(0), clipValueMax: Tensor(255)))
            }
        default:
            print("Image saving isn't supported for the format \(format).")
            exit(-1)
        }

        let encodedJpeg = _Raw.encodeJpeg(
            image: outputImageData, format: format, quality: quality, xmpMetadata: "")
        _Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg)
    }

    public func resized(to size: (Int, Int)) -> Image {
        switch self.imageData {
        case let .uint8(data):
            return Image(
                tensor: _Raw.resizeBilinear(
                    images: Tensor<UInt8>([data]),
                    size: Tensor<Int32>([Int32(size.0), Int32(size.1)])))
        case let .float(data):
            return Image(
                tensor: _Raw.resizeBilinear(
                    images: Tensor<Float>([data]),
                    size: Tensor<Int32>([Int32(size.0), Int32(size.1)])))
        }

    }
}

public func saveImage(_ tensor: Tensor<Float>, size: (Int, Int), directory: String, name: String)
    throws
{
    try createDirectoryIfMissing(at: directory)
    let reshapedTensor = tensor.reshaped(to: [size.0, size.1, 1])
    let image = Image(tensor: reshapedTensor)
    let outputURL = URL(fileURLWithPath: "\(directory)\(name).jpg")
    image.save(to: outputURL)
}

In [11]:
// https://github.com/tensorflow/swift-models/blob/master/GAN/main.swift

var generator = Generator()
var discriminator = Discriminator()

let optG = Adam(for: generator, learningRate: 2e-4, beta1: 0.5)
let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)

// Noise vectors and plot function for testing
let testImageGridSize = 4
let testVector = sampleVector(size: testImageGridSize * testImageGridSize)

func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
    var gridImage = testImage.reshaped(
        to: [
            testImageGridSize, testImageGridSize,
            imageHeight, imageWidth,
        ])
    // Add padding.
    gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
    // Transpose to create single image.
    gridImage = gridImage.transposed(permutation: [0, 2, 1, 3])
    gridImage = gridImage.reshaped(
        to: [
            (imageHeight + 2) * testImageGridSize,
            (imageWidth + 2) * testImageGridSize,
        ])
    // Convert [-1, 1] range to [0, 1] range.
    gridImage = (gridImage + 1) / 2

    try saveImage(
        gridImage, size: (gridImage.shape[0], gridImage.shape[1]), directory: outputFolder,
        name: name)
}

print("Start training...")

// Start training loop.
for epoch in 1...epochCount {
    // Start training phase.
    Context.local.learningPhase = .training
    let trainingShuffled = dataset.trainingDataset.shuffled(
        sampleCount: dataset.trainingExampleCount, randomSeed: Int64(epoch))
    for batch in trainingShuffled.batched(batchSize) {
        // Perform alternative update.
        // Update generator.
        let vec1 = sampleVector(size: batchSize)

        let 𝛁generator = TensorFlow.gradient(at: generator) { generator -> Tensor<Float> in
            let fakeImages = generator(vec1)
            let fakeLogits = discriminator(fakeImages)
            let loss = generatorLoss(fakeLogits: fakeLogits)
            return loss
        }
        optG.update(&generator, along: 𝛁generator)

        // Update discriminator.
        let realImages = batch.data
        let vec2 = sampleVector(size: batchSize)
        let fakeImages = generator(vec2)

        let 𝛁discriminator = TensorFlow.gradient(at: discriminator) { discriminator -> Tensor<Float> in
            let realLogits = discriminator(realImages)
            let fakeLogits = discriminator(fakeImages)
            let loss = discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits)
            return loss
        }
        optD.update(&discriminator, along: 𝛁discriminator)
    }

    // Start inference phase.
    Context.local.learningPhase = .inference
    let testImage = generator(testVector)

    do {
        try saveImageGrid(testImage, name: "epoch-\(epoch)-output")
    } catch {
        print("Could not save image grid with error: \(error)")
    }

    let lossG = generatorLoss(fakeLogits: testImage)
    print("[Epoch: \(epoch)] Loss-G: \(lossG)")
}

Start training...
[Epoch: 1] Loss-G: 1.0132707
[Epoch: 2] Loss-G: 1.0107026
[Epoch: 3] Loss-G: 1.0149908
[Epoch: 4] Loss-G: 1.0202806
[Epoch: 5] Loss-G: 1.0052707
[Epoch: 6] Loss-G: 0.9845302
[Epoch: 7] Loss-G: 0.9998964
[Epoch: 8] Loss-G: 0.99334157
[Epoch: 9] Loss-G: 0.9954661
[Epoch: 10] Loss-G: 0.9898559
