In [None]:
%install '.package(path: "$cwd/FastaiNotebooks")' FastaiNotebooks

Installing packages:
	.package(path: "/home/ubuntu/fastai_docs/dev_swift/FastaiNotebooks")
		FastaiNotebooks
With SwiftPM flags: []
Working in: /tmp/tmpiqlvfyo4
Fetching https://github.com/mxcl/Path.swift
Fetching https://github.com/JustHTTP/Just
Completed resolution in 2.85s
Cloning https://github.com/mxcl/Path.swift
Resolving https://github.com/mxcl/Path.swift at 0.16.2
Cloning https://github.com/JustHTTP/Just
Resolving https://github.com/JustHTTP/Just at 0.7.1
Compile Swift Module 'Just' (1 sources)
Compile Swift Module 'Path' (9 sources)
Compile Swift Module 'FastaiNotebooks' (4 sources)
Compile Swift Module 'jupyterInstalledPackages' (1 sources)
Linking ./.build/x86_64-unknown-linux/debug/libjupyterInstalledPackages.so
Initializing Swift...
Loading library...
Installation complete!


# Implement the callback mechanism

In [None]:
import TensorFlow

/// One element of a training dataset: a group of inputs paired with its labels.
struct DataBatch<Inputs, Labels>: TensorGroup
where Inputs: Differentiable & TensorGroup,
      Labels: TensorGroup
{
    /// The inputs of this batch.
    var xb: Inputs
    /// The labels of this batch.
    var yb: Labels
}

/// What a callback asks the learner to do after handling an event.
///
/// The learner only continues past an event when every callback answers `.proceed`;
/// the first `.skip` or `.stop` is returned to the caller of the step that fired the event.
enum CallbackResult {
    case proceed, skip, stop
}

/// The points in the training loop at which callbacks are invoked.
enum CallbackEvent {
    // I haven't implemented all the events.
    case beginFit, beginEpoch, beginBatch, afterForwardsBackwards
}

/// Base class for training callbacks.
///
/// Subclasses override `apply(event:learner:)` to react to the events they care about.
class Callback<Opt, Labels>
where Opt: Optimizer,
      Labels: TensorGroup,
      Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input: TensorGroup
{
    /// Handles `event` for `learner`. The base implementation does nothing.
    ///
    /// - Parameters:
    ///   - event: The training-loop event being fired.
    ///   - learner: The learner that fired the event.
    /// - Returns: Always `.proceed`.
    func apply(event: CallbackEvent, learner: Learner<Opt, Labels>) -> CallbackResult {
        return CallbackResult.proceed
    }
}

/// Trains `model` on `data` with `optimizer`, notifying `callbacks` at each `CallbackEvent`.
///
/// Any callback result other than `.proceed` short-circuits the step that fired the event,
/// and the result is propagated upwards: a non-`.proceed` batch ends its epoch, and a
/// non-`.proceed` epoch ends `fit(epochs:)`.
class Learner<Opt: Optimizer, Labels: TensorGroup>
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input: TensorGroup
{
    typealias Model = Opt.Model
    /// The model being trained; `optimizer` updates its differentiable variables in place.
    var model: Model

    typealias Inputs = Model.Input

    // I'm getting some crashes in AD-generated code if I put a `lossFunc` in the learner.
    // So I'm putting a `lossWithGradient` for now, to work around this.
    // (model, context, inputs, labels) -> (loss, grad)
    typealias LossWithGradient = (Model, Context, Inputs, Labels) -> (Tensor<Float>, Model.AllDifferentiableVariables)
    /// Computes the loss and the gradient for one batch.
    var lossWithGradient: LossWithGradient

    /// The optimizer that applies `grad` to the model's variables.
    var optimizer: Opt

    typealias Data = Dataset<DataBatch<Inputs, Labels>>
    /// The training data, iterated once per epoch.
    var data: Data

    /// Callbacks, invoked in array order for every event.
    var callbacks: [Callback<Opt, Labels>]

    /// Loss of the batch currently being processed; reset to zero once the batch completes.
    var loss: Tensor<Float> = Tensor(0)
    /// Gradient of the batch currently being processed; reset to zero once the batch completes.
    var grad: Model.AllDifferentiableVariables = Model.AllDifferentiableVariables.zero

    /// The 1-based index of the epoch in progress (0 before `fit` starts an epoch).
    var epoch: Int = 0
    /// The number of epochs requested by the latest call to `fit(epochs:)`.
    var epochs: Int = 0

    init(
        model: Model,
        lossWithGradient: @escaping LossWithGradient,
        optimizer: Opt,
        data: Data,
        callbacks: [Callback<Opt, Labels>]
    ) {
        self.model = model
        self.lossWithGradient = lossWithGradient
        self.optimizer = optimizer
        self.data = data
        self.callbacks = callbacks
    }

    /// Resets `loss` and `grad` to zero.
    private func resetPerBatchValues() {
        self.loss = Tensor(0)
        self.grad = Model.AllDifferentiableVariables.zero
    }

    /// Runs one training step on a single batch.
    ///
    /// Fires `.beginBatch`, computes loss and gradient, fires `.afterForwardsBackwards`,
    /// then applies the optimizer update.
    ///
    /// - Parameters:
    ///   - xb: The inputs of the batch.
    ///   - yb: The labels of the batch.
    /// - Returns: `.proceed` if the step completed, otherwise the first non-`.proceed`
    ///   callback result (in which case the optimizer update is not applied).
    func trainOneBatch(xb: Inputs, yb: Labels) -> CallbackResult {
        var cbResult = runCallbacks(event: .beginBatch)
        if cbResult != .proceed {
            return cbResult
        }
        let context = Context(learningPhase: .training)
        (self.loss, self.grad) = lossWithGradient(model, context, xb, yb)
        defer {
            // Zero out the loss & gradient to ensure stale values aren't used.
            resetPerBatchValues()
        }
        cbResult = runCallbacks(event: .afterForwardsBackwards)
        if cbResult != .proceed {
            return cbResult
        }
        optimizer.update(&model.allDifferentiableVariables, along: self.grad)
        return .proceed
    }

    /// Fires `.beginEpoch` and then trains on every batch of `data`.
    ///
    /// - Returns: `.proceed` if every batch completed, otherwise the first non-`.proceed`
    ///   result; the remaining batches of the epoch are not processed in that case.
    func trainOneEpoch() -> CallbackResult {
        switch runCallbacks(event: .beginEpoch) {
            case .stop: return .stop
            case .skip:
                print("Unexpected .skip returned from running callbacks(event: .beginEpoch)")
                return .skip
            case .proceed: break
        }
        for batch in self.data {
            let cbResult = trainOneBatch(xb: batch.xb, yb: batch.yb)
            if cbResult != .proceed {
                return cbResult
            }
        }
        return .proceed
    }

    /// Trains for `epochs` epochs, numbering them from 1.
    ///
    /// Fires `.beginFit` first. Training ends early as soon as `.beginFit` or any epoch
    /// yields a result other than `.proceed`.
    ///
    /// - Parameter epochs: The number of epochs to run. Values below 1 run no epochs
    ///   (the `.beginFit` callbacks are still invoked).
    func fit(epochs: Int) {
        // I haven't implemented validation.
        self.epochs = epochs
        if runCallbacks(event: .beginFit) != .proceed {
            return
        }
        // `1...epochs` traps at runtime when `epochs < 1`, so treat that as "nothing to train".
        guard epochs >= 1 else {
            return
        }
        for epoch in 1...epochs {
            self.epoch = epoch
            if trainOneEpoch() != .proceed {
                return
            }
        }
    }

    /// Invokes every callback for `event`, in order, stopping at the first that does not
    /// return `.proceed`.
    ///
    /// - Returns: The first non-`.proceed` result, or `.proceed` if there was none.
    private func runCallbacks(event: CallbackEvent) -> CallbackResult {
        for callback in callbacks {
            let cbResult = callback.apply(event: event, learner: self)
            if cbResult != .proceed {
                return cbResult
            }
        }
        return .proceed
    }
}

# Implement some example callbacks

In [None]:
// Pull in the notebook display helper (presumably what defines `IPythonDisplay` — confirm).
%include "EnableIPythonDisplay.swift"
// Import matplotlib's pyplot module through the Python interop layer.
let plt = Python.import("matplotlib.pyplot")
// Ask the notebook shell to use matplotlib's "inline" mode.
IPythonDisplay.shell.enable_matplotlib("inline")

/// Callback that records the loss and the learning rate after every forward/backward pass.
class Recorder<Opt: Optimizer, Labels: TensorGroup> : Callback<Opt, Labels>
// Hmm, this boilerplate is kind of annoying.
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input: TensorGroup,
      // Extra constraints can restrict a callback to certain kinds of learners.
      // Here the optimizer's scalar type must be `Float` so that `plt.plot` understands
      // the learning rate.
      Opt.Scalar == Float
{

    /// Loss recorded at each `.afterForwardsBackwards` event since the last `.beginFit`.
    var losses: [Float] = []
    /// Learning rate recorded at each `.afterForwardsBackwards` event since the last `.beginFit`.
    var lrs: [Float] = []

    /// Clears the history on `.beginFit` and appends to it on `.afterForwardsBackwards`.
    ///
    /// - Returns: Always `.proceed`.
    override func apply(event: CallbackEvent, learner: Learner<Opt, Labels>) -> CallbackResult {
        if event == .beginFit {
            losses = []
            lrs = []
        } else if event == .afterForwardsBackwards {
            // Force-unwrapping traps if the learner's loss is not a scalar tensor.
            losses.append(learner.loss.scalar!)
            lrs.append(learner.optimizer.learningRate)
        }
        return .proceed
    }

    /// Plots the recorded losses with matplotlib.
    func plotLosses() {
        plt.plot(losses)
    }

    /// Plots the recorded learning rates with matplotlib.
    func plotLrs() {
        plt.plot(lrs)
    }
}

In [None]:
/// Callback that prints a line at the start of every epoch.
class Progress<Opt: Optimizer, Labels: TensorGroup> : Callback<Opt, Labels>
// Hmm, this boilerplate is kind of annoying.
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input: TensorGroup {
    /// Prints the learner's current and total epoch counts on `.beginEpoch`.
    ///
    /// - Returns: Always `.proceed`.
    override func apply(event: CallbackEvent, learner: Learner<Opt, Labels>) -> CallbackResult {
        if event == .beginEpoch {
            print("Starting new epoch: \(learner.epoch) of \(learner.epochs)")
        }
        return .proceed
    }
}

In [None]:
/// Callback that overwrites a learner property at the start of every batch with a value
/// computed from the training progress.
class ParamScheduler<Opt: Optimizer, Labels: TensorGroup, Param> : Callback<Opt, Labels>
// Hmm, this boilerplate is kind of annoying.
where Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
      Opt.Model.Input: TensorGroup,
      Opt.Model.Output: TensorGroup
{

    /// The learner property that receives the scheduled value.
    let paramKeyPath: ReferenceWritableKeyPath<Learner<Opt, Labels>, Param>
    /// Maps `learner.epoch / learner.epochs` to the value to assign.
    let schedule: (Float) -> Param

    /// - Parameters:
    ///   - paramKeyPath: Key path, rooted at the learner, of the property to schedule.
    ///   - schedule: Function from the progress fraction to the property value.
    init(paramKeyPath: ReferenceWritableKeyPath<Learner<Opt, Labels>, Param>, schedule: @escaping (Float) -> Param) {
        self.paramKeyPath = paramKeyPath
        self.schedule = schedule
    }

    /// On `.beginBatch`, assigns the scheduled value for the current epoch to the property.
    ///
    /// - Returns: Always `.proceed`.
    override func apply(event: CallbackEvent, learner: Learner<Opt, Labels>) -> CallbackResult {
        guard event == .beginBatch else {
            return .proceed
        }
        let progress = Float(learner.epoch) / Float(learner.epochs)
        learner[keyPath: paramKeyPath] = schedule(progress)
        return .proceed
    }
}

# The model and data

In [None]:
import FastaiNotebooks
import Path

// Load the four MNIST tensors from `~/.fastai/data/mnist_tst`; `var` because `xTrain` is reshaped later.
var (xTrain,yTrain,xValid,yValid) = loadMNIST(path: Path.home/".fastai"/"data"/"mnist_tst")

In [None]:
// Reshape the training inputs to [60000, 784].
xTrain = xTrain.reshaped(toShape: [60000, 784])

// Row and column counts of the reshaped training inputs.
let (n, m) = (Int(xTrain.shape[0]), Int(xTrain.shape[1]))
// One more than the largest training label. (Was `y_train`, which is not defined; the
// labels loaded above are named `yTrain`.)
let c = yTrain.max() + 1

// Hidden layer size and batch size.
let nh = 50
let bs: Int32 = 64

// Training dataset of (input, label) pairs, grouped into batches of `bs`.
let train_ds: Dataset<DataBatch> = Dataset(elements: DataBatch(xb: xTrain, yb: yTrain)).batched(Int64(bs))

In [None]:
/// Output size of the model's final dense layer.
let outputCount = 10

/// Two dense layers applied in sequence, plus a non-trained additive term.
struct MyModel: Layer {
    /// Hidden layer: `m` inputs, `nh` outputs, ReLU activation.
    var layer1 = Dense<Float>(inputSize: m, outputSize: nh, activation: relu)
    /// Output layer: `nh` inputs, `outputCount` outputs.
    var layer2 = Dense<Float>(inputSize: nh, outputSize: outputCount)

    /// A silly non-trained parameter to show off the parameter scheduler.
    @noDerivative var sillyExtraBiasParam: Tensor<Float> = Tensor(zeros: [Int32(outputCount)])

    /// Runs `input` through both layers and adds `sillyExtraBiasParam` to the result.
    @differentiable
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        let layerOutput = input.sequenced(in: context, through: layer1, layer2)
        return layerOutput + sillyExtraBiasParam
    }
}

/// The model instance that is handed to the learner.
var model = MyModel()

/// Computes the softmax cross-entropy loss of `model` on one batch, together with the
/// gradient of that loss with respect to the model.
///
/// - Parameters:
///   - model: The model to evaluate.
///   - context: The context passed to the model's `applied(to:in:)`.
///   - inputs: The batch inputs.
///   - labels: The batch labels.
/// - Returns: The loss and the gradient, in that order.
func lossWithGrad(
    model: MyModel,
    in context: Context,
    inputs: Tensor<Float>,
    labels: Tensor<Int32>
) -> (Tensor<Float>, MyModel.AllDifferentiableVariables) {
    return model.valueWithGradient { candidate -> Tensor<Float> in
        let logits = candidate.applied(to: inputs, in: context)
        return softmaxCrossEntropy(logits: logits, labels: labels)
    }
}

# Run the learner

In [None]:
// Typealiases so the concrete optimizer and learner types don't have to be spelled out
// repeatedly: SGD over `MyModel` with `Float` scalars, and a learner with `Tensor<Int32>` labels.
typealias MyOptimizer = SGD<MyModel, Float>
typealias MyLearner = Learner<MyOptimizer, Tensor<Int32>>

In [None]:
// SGD optimizer with a learning rate of 0.01.
let optimizer = MyOptimizer(learningRate: 0.01)

In [None]:
// The learning rate can't be scheduled because the `Optimizer` protocol doesn't allow
// setting it. If the protocol is changed to allow that, `ParamScheduler` should be able to
// schedule it with `paramKeyPath: \MyLearner.optimizer.learningRate`.
// Until then, schedule the model's non-trained bias: zeros while the progress fraction is
// below 0.5, a fixed non-zero vector afterwards.
let scheduler = ParamScheduler(paramKeyPath: \MyLearner.model.sillyExtraBiasParam) { progress in
    if progress < 0.5 {
        return Tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    }
    return Tensor([10, 20, 30, 0, 0, 0, 0, 0, 0, 0])
}

In [None]:
// Keep a named reference to the recorder so its history can be plotted after training.
let recorder = Recorder<MyOptimizer, Tensor<Int32>>()

In [None]:
// Assemble the learner; callbacks run in the order listed for every event.
let learner = Learner(
    model: model,
    lossWithGradient: lossWithGrad,
    optimizer: optimizer,
    data: train_ds,
    callbacks: [Progress(), recorder, scheduler]
)

In [None]:
// Train for six epochs; the `Progress` callback prints one line per epoch.
learner.fit(epochs: 6)

Starting new epoch: 1 of 6
Starting new epoch: 2 of 6
Starting new epoch: 3 of 6
Starting new epoch: 4 of 6
Starting new epoch: 5 of 6
Starting new epoch: 6 of 6


In [None]:
// Plot the losses recorded during training and display the figure.
recorder.plotLosses()
plt.show()