# Callbacks version 2b

In [None]:
%install '.package(path: "$cwd/FastaiNotebook_03_minibatch_training")' FastaiNotebook_03_minibatch_training

## Load data

In [None]:
import FastaiNotebook_03_minibatch_training

In [None]:
// export
import Path
import TensorFlow

In [None]:
// Load the MNIST training and validation tensors from the default `mnistPath`.
var (xTrain, yTrain, xValid, yValid) = loadMNIST(path: mnistPath, flat: true)

In [None]:
// Sizes read from the data: samples, input features, and number of classes.
// Note that `c` is computed as a tensor here.
let n = Int(xTrain.shape[0])
let m = Int(xTrain.shape[1])
let c = yTrain.max() + 1
print(n, m, c)

`c` is a `Tensor<Int32>` rather than an `Int`, so it can't be used to define a model; instead, we hard-code the sizes:

In [None]:
// Hard-coded sizes: training samples, input features, classes, and hidden-layer width.
let n = 60_000
let m = 784
let c = 10
let nHid = 50

In [None]:
// export
/// A two-layer fully-connected network: a ReLU `Dense` layer followed by a `Dense` output layer.
public struct BasicModel: Layer {
    /// Hidden layer mapping `nIn` features to `nHid` ReLU activations.
    public var layer1: Dense<Float>
    /// Output layer mapping `nHid` activations to `nOut` outputs (no activation is passed).
    public var layer2: Dense<Float>

    /// Creates the model.
    ///
    /// - Parameters:
    ///   - nIn: Number of input features.
    ///   - nHid: Width of the hidden layer.
    ///   - nOut: Number of outputs.
    public init(nIn: Int, nHid: Int, nOut: Int) {
        self.layer1 = Dense(inputSize: nIn, outputSize: nHid, activation: relu)
        self.layer2 = Dense(inputSize: nHid, outputSize: nOut)
    }

    /// Feeds `input` through `layer1` and then through `layer2`.
    @differentiable
    public func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        let hidden = layer1.applied(to: input, in: context)
        return layer2.applied(to: hidden, in: context)
    }
}

In [None]:
// A standalone model instance built from the notebook-level sizes.
var model = BasicModel(
    nIn: m,
    nHid: nHid,
    nOut: c
)

In [None]:
// export
/// Pairs a training dataset with a validation dataset.
public struct DataBunch<Element: TensorGroup> {
    /// Batches used to fit the model.
    public var train: Dataset<Element>
    /// Batches used to evaluate the model.
    public var valid: Dataset<Element>

    /// Creates a bunch from its two datasets.
    public init(train: Dataset<Element>, valid: Dataset<Element>) {
        self.valid = valid
        self.train = train
    }
}

In [None]:
//export
/// Loads MNIST and wraps the training and validation splits into batched datasets.
///
/// - Parameters:
///   - path: Location passed to `loadMNIST`.
///   - flat: Forwarded to `loadMNIST`.
///   - bs: Batch size used for both splits.
/// - Returns: A `DataBunch` of `(xb, yb)` batches for training and validation.
public func mnistDataBunch(path: Path = mnistPath, flat: Bool = false, bs: Int = 64
                          ) -> DataBunch<DataBatch<Tensor<Float>, Tensor<Int32>>> {
    let (xTrain, yTrain, xValid, yValid) = loadMNIST(path: path, flat: flat)
    let batchSize = Int64(bs)
    let trainSet = Dataset(elements: DataBatch(xb: xTrain, yb: yTrain)).batched(batchSize)
    let validSet = Dataset(elements: DataBatch(xb: xValid, yb: yValid)).batched(batchSize)
    return DataBunch(train: trainSet, valid: validSet)
}

In [None]:
// Batched MNIST splits with flattened inputs; path and batch size spelled out at their defaults.
let data = mnistDataBunch(path: mnistPath, flat: true, bs: 64)

## Learner (Richard's version)

In [None]:
// export
/// Control-flow signals a delegate can throw to alter the training loop:
/// `skipEpoch` to request that the current epoch be abandoned, `skipBatch` to request that
/// the current batch be abandoned, and `stop` to request that training end.
public enum LearnerAction: Error {
    case skipEpoch, skipBatch, stop
}

The basic `Learner` class:

In [None]:
// export
/// A model learner, responsible for initializing and training a model on a given dataset.
///
/// Training is driven by `fit(_:)`; behavior around the loop is customized by subclassing
/// `Delegate` and adding instances to `delegates`.
// NOTE(review): an earlier note said "when TF-421 is fixed, make `Label` not constrained to
// `Differentiable`", but `Label` is currently only constrained to `TensorGroup` — confirm
// whether that note still applies.
public final class Learner<Label: TensorGroup,
                           O: TensorFlow.Optimizer & AnyObject>
    where O.Scalar: Differentiable,
          O.Model.Input: TensorGroup
{
    // Common type aliases.
    public typealias Input = Model.Input
    public typealias Data = DataBunch<DataBatch<Input, Label>>
    public typealias Loss = Tensor<Float>
    public typealias Optimizer = O
    public typealias Model = Optimizer.Model
    public typealias Variables = Model.AllDifferentiableVariables
    // NOTE: When TF-421 is fixed, replace with:
    //   public typealias LossFunction = @differentiable (Model.Output, @nondiff Label) -> Loss
    /// Given the model, the layer context, an input batch and its labels, returns the loss,
    /// the (optional) model output, and the gradient used to update the model.
    public typealias LossOutputWithGradient = (Model, Context, Input, Label
                                 ) -> (Loss, Model.Output?, Model.CotangentVector)
    public typealias EventHandler = (Learner) throws -> Void
    
    /// The training and validation datasets.
    public var data: Data
    /// The optimizer used for updating model parameters along gradient vectors.
    public var optimizer: Optimizer
    /// Computes the loss, the model output and the gradient for one batch.
    public var lossOutputWithGradient: LossOutputWithGradient
    /// The model being trained.
    public var model: Model
    
    // The current batch's input and target, and the model output for it. They are
    // optionals because no batch exists before the loop starts.
    // TODO: find a way to initialize these without making them optionals.
    public var currentInput: Input? = nil
    public var currentTarget: Label? = nil
    public var currentOutput: Model.Output? = nil
    
    /// The number of epochs requested by the most recent `fit` call.
    public private(set) var epochCount: Int = .zero
    /// The index of the current epoch.
    public private(set) var currentEpoch: Int = .zero
    /// The gradient produced for the current batch.
    public private(set) var currentGradient: Model.CotangentVector = .zero
    /// The loss produced for the current batch.
    public private(set) var currentLoss: Loss = .zero
    /// Whether the learner is in training mode; the `Learner` itself never sets it, delegates
    /// such as `TrainEvalDelegate` do.
    public private(set) var inTrain: Bool = false
    /// Training progress measured in epochs (epoch + fraction of its iterations), maintained
    /// by delegates such as `TrainEvalDelegate`.
    public private(set) var pctEpochs: Float = 0.0
    /// The count of training iterations, maintained by delegates such as `TrainEvalDelegate`.
    public private(set) var currentIter: Int = 0
    /// The number of batches in the dataset currently being iterated.
    public private(set) var iterCount: Int = 0
    
    /// Base class for training-loop callbacks; override only the hooks you need.
    /// Every hook may throw a `LearnerAction` to alter the loop.
    open class Delegate {
        public init () {}
        
        /// Called once before the first epoch.
        open func trainingWillStart(learner: Learner) throws {}
        /// Called once after the last epoch completes.
        open func trainingDidFinish(learner: Learner) throws {}
        /// Called at the start of each epoch.
        open func epochWillStart(learner: Learner) throws {}
        /// Called at the end of each epoch.
        open func epochDidFinish(learner: Learner) throws {}
        /// Called before the validation pass of each epoch.
        open func validationWillStart(learner: Learner) throws {}
        /// Called before each batch is processed.
        open func batchWillStart(learner: Learner) throws {}
        /// Called after each batch is processed.
        open func batchDidFinish(learner: Learner) throws {}
        /// Called when a new gradient has been computed, before the optimizer update.
        open func learnerDidProduceNewGradient(learner: Learner) throws {}
        /// Intended to be called after an optimizer update.
        /// NOTE(review): not invoked by the training loop as currently written.
        open func optimizerDidUpdate(learner: Learner) throws {}
        
        /// TODO: learnerDidProduceNewOutput and learnerDidProduceNewLoss need to
        /// be differentiable once we can have the loss function inside the Learner
    }
    /// Callbacks notified, in order, at each stage of the training loop.
    public var delegates: [Delegate] = []
    
    /// The context passed to `lossOutputWithGradient` for layer applications.
    public private(set) var context = Context(learningPhase: .training)

    /// Creates a learner.
    ///
    /// - Parameters:
    ///   - data: The training and validation datasets.
    ///   - lossOutputWithGradient: Computes the loss, output and gradient for a batch.
    ///   - optimizer: The optimizer used for updating model parameters along
    ///     gradient vectors.
    ///   - modelInitializer: The closure that produces the model to be trained; it is
    ///     called once, here.
    ///
    public init(data: Data,
                lossOutputWithGradient: @escaping LossOutputWithGradient,
                optimizer: Optimizer,
                initializingWith modelInitializer: () -> Model) {
        self.data = data
        self.optimizer = optimizer
        self.lossOutputWithGradient = lossOutputWithGradient
        self.model = modelInitializer()
    }
}

Then let's write the parts of the training loop:

In [None]:
// export
extension Learner {
    /// Trains the model on the given batch.
    ///
    /// Computes the loss, output and gradient, notifies delegates that a new gradient
    /// exists, then lets the optimizer update the model.
    ///
    /// - Parameter batch: The batch of input data and labels to be trained on.
    /// - Throws: Any error thrown by a delegate's `learnerDidProduceNewGradient`.
    private func train(onBatch batch: DataBatch<Input, Label>) throws {
        (currentLoss, currentOutput, currentGradient) = lossOutputWithGradient(model, context, batch.xb, batch.yb)
        try delegates.forEach { try $0.learnerDidProduceNewGradient(learner: self) }
        optimizer.update(&model.allDifferentiableVariables, along: currentGradient)
    }

    /// Performs one pass over `ds`, training on every batch.
    ///
    /// A delegate may throw `LearnerAction.skipBatch` from `batchWillStart` or
    /// `learnerDidProduceNewGradient` to abandon the current batch (its `batchDidFinish`
    /// is not called); the loop then continues with the next batch. Other errors propagate.
    private func train(onDataset ds: Dataset<DataBatch<Input, Label>>) throws {
        iterCount = ds.count(where: { _ in true })
        for batch in ds {
            (currentInput, currentTarget) = (batch.xb, batch.yb)
            do {
                try delegates.forEach { try $0.batchWillStart(learner: self) }
                try train(onBatch: batch)
            } catch LearnerAction.skipBatch {
                // `continue`, not `break`: skipping one batch must not end the whole pass.
                continue
            }
            try delegates.forEach { try $0.batchDidFinish(learner: self) }
        }
    }
}

And finally, the complete `fit` function:

In [None]:
// export
extension Learner {
    /// Runs one pass over `ds` without updating the model.
    ///
    /// For each batch it records the input/target, computes loss, output and gradient via
    /// `lossOutputWithGradient`, and calls `batchWillStart`/`batchDidFinish` so delegates
    /// (e.g. metrics) can observe the results. No optimizer step is taken and
    /// `learnerDidProduceNewGradient` is not called. A delegate may throw
    /// `LearnerAction.skipBatch` from `batchWillStart` to skip the current batch.
    private func evaluate(onDataset ds: Dataset<DataBatch<Input, Label>>) throws {
        iterCount = ds.count(where: { _ in true })
        for batch in ds {
            (currentInput, currentTarget) = (batch.xb, batch.yb)
            do {
                try delegates.forEach { try $0.batchWillStart(learner: self) }
                (currentLoss, currentOutput, currentGradient) = lossOutputWithGradient(model, context, batch.xb, batch.yb)
            } catch LearnerAction.skipBatch {
                continue
            }
            try delegates.forEach { try $0.batchDidFinish(learner: self) }
        }
    }

    /// Starts fitting.
    ///
    /// For each epoch, trains on `data.train` and then evaluates on `data.valid`; the
    /// validation pass never updates the model. A delegate may throw
    /// `LearnerAction.skipEpoch` to abandon the rest of the current epoch (its
    /// `epochDidFinish` is not called) and move on to the next one, or
    /// `LearnerAction.stop` to end training immediately (`trainingDidFinish` is not called).
    ///
    /// - Parameter epochCount: The number of epochs that will be run.
    /// - Throws: Any error thrown by a delegate other than the `LearnerAction`s handled here.
    public func fit(_ epochCount: Int) throws {
        self.epochCount = epochCount
        do {
            try delegates.forEach { try $0.trainingWillStart(learner: self) }
            for i in 0..<epochCount {
                currentEpoch = i
                do {
                    try delegates.forEach { try $0.epochWillStart(learner: self) }
                    try train(onDataset: data.train)
                    try delegates.forEach { try $0.validationWillStart(learner: self) }
                    try evaluate(onDataset: data.valid)
                } catch LearnerAction.skipEpoch {
                    // `continue`, not `break`: skipping one epoch must not end training.
                    continue
                }
                try delegates.forEach { try $0.epochDidFinish(learner: self) }
            }
            try delegates.forEach { try $0.trainingDidFinish(learner: self) }
        } catch LearnerAction.stop {
            return
        }
    }
}

### Test

In [None]:
// Plain SGD over the BasicModel parameters.
let opt = SGD<BasicModel, Float>(learningRate: 0.01)

In [None]:
/// Builds a fresh `BasicModel` sized from the notebook-level `m`, `nHid` and `c`.
func modelInit() -> BasicModel {
    return BasicModel(nIn: m, nHid: nHid, nOut: c)
}

In [None]:
/// Computes the softmax cross-entropy loss of `model` on one batch, together with the
/// model's predictions and the gradient of the loss with respect to the model.
func lossOutputWithGrad(
    model: BasicModel,
    in context: Context,
    inputs: Tensor<Float>,
    labels: Tensor<Int32>
) -> (Tensor<Float>, BasicModel.Output, BasicModel.CotangentVector) {
    // Captured from inside the differentiated closure so the forward pass runs only once.
    var capturedOutput: BasicModel.Output?
    let (loss, gradient) = model.valueWithGradient { m -> Tensor<Float> in
        let logits = m.applied(to: inputs, in: context)
        capturedOutput = logits
        return softmaxCrossEntropy(logits: logits, labels: labels)
    }
    // `valueWithGradient` evaluates the closure, so the output has been set by now.
    return (loss, capturedOutput!, gradient)
}

In [None]:
let learner = Learner(
    data: data,
    lossOutputWithGradient: lossOutputWithGrad,
    optimizer: opt,
    initializingWith: modelInit
)

In [None]:
// `fit` is declared `throws`, so the call must be marked with `try`.
try learner.fit(2)

## Let's add Callbacks!

### Train/eval

Callbacks are subclasses of `Learner.Delegate`, declared as nested classes inside extensions of `Learner`.

In [None]:
// export
extension Learner {
    /// Switches the learner between training and inference mode and tracks training progress.
    public class TrainEvalDelegate: Delegate {
        /// Resets the progress counters before training begins.
        public override func trainingWillStart(learner: Learner) throws {
            learner.currentIter = 0
            learner.pctEpochs = 0.0
        }

        /// Enters training mode and aligns `pctEpochs` with the start of the epoch.
        public override func epochWillStart(learner: Learner) throws {
            learner.inTrain = true
            learner.context = Context(learningPhase: .training)
            learner.pctEpochs = Float(learner.currentEpoch)
        }

        /// Advances the progress counters after each training batch.
        public override func batchDidFinish(learner: Learner) throws {
            guard learner.inTrain else { return }
            learner.currentIter += 1
            learner.pctEpochs += 1.0 / Float(learner.iterCount)
        }

        /// Enters inference mode before validation.
        public override func validationWillStart(learner: Learner) throws {
            learner.inTrain = false
            learner.context = Context(learningPhase: .inference)
        }
    }
}

In [None]:
let learner = Learner(
    data: data,
    lossOutputWithGradient: lossOutputWithGrad,
    optimizer: opt,
    initializingWith: modelInit
)

In [None]:
// Only the train/eval mode switcher for now.
learner.delegates = [
    Learner.TrainEvalDelegate(),
]

In [None]:
// `fit` is declared `throws`, so the call must be marked with `try`.
try learner.fit(2)

### AverageMetric

In [None]:
// export
extension Learner {
    /// Accumulates the loss and the given metrics over the validation batches of each
    /// epoch, weighted by batch size, and prints their averages when the epoch finishes.
    public class AvgMetric: Delegate {
        /// Metric functions taking `(output, target)` and returning a tensor value.
        public let metrics: [(Tensor<Float>, Tensor<Int32>) -> Tensor<Float>]
        /// Number of validation samples accumulated in the current epoch.
        var total: Int = 0
        /// Running weighted sums: index 0 is the loss, index `i + 1` is `metrics[i]`.
        var partials: [Tensor<Float>] = []

        public init(metrics: [(Tensor<Float>, Tensor<Int32>) -> Tensor<Float>]) {
            self.metrics = metrics
        }

        public override func epochWillStart(learner: Learner) throws {
            total = 0
            partials = Array(repeating: Tensor(0), count: metrics.count + 1)
        }

        public override func batchDidFinish(learner: Learner) throws {
            // Only validation batches with a usable target and output contribute; skipping
            // the whole batch otherwise keeps `total` consistent with every partial sum.
            guard !learner.inTrain,
                  let target = learner.currentTarget as? Tensor<Int32>,
                  let output = learner.currentOutput as? Tensor<Float> else { return }
            let bs = target.shape[0]
            total += Int(bs)
            partials[0] += Float(bs) * learner.currentLoss
            // `enumerated()` is safe for an empty `metrics`, unlike `1...metrics.count`.
            for (i, metric) in metrics.enumerated() {
                partials[i + 1] += Float(bs) * metric(output, target)
            }
        }

        public override func epochDidFinish(learner: Learner) throws {
            partials = partials.map { $0 / Float(total) }
            print("Epoch \(learner.currentEpoch): \(partials)")
        }
    }
}

In [None]:
let learner = Learner(
    data: data,
    lossOutputWithGradient: lossOutputWithGrad,
    optimizer: opt,
    initializingWith: modelInit
)

In [None]:
// Mode switching first, then validation-set averaging of loss and accuracy.
learner.delegates = [
    Learner.TrainEvalDelegate(),
    Learner.AvgMetric(metrics: [accuracy]),
]

In [None]:
// `fit` is declared `throws`, so the call must be marked with `try`.
try learner.fit(2)

## Export

In [None]:
// Export the `// export` cells of this notebook to a Swift source file.
let notebookFile = Path.cwd / "04_callbacks.ipynb"
notebookToScript(fname: notebookFile.string)