# Training Utilities

This notebook presents a design for training utilities.

In [5]:
%install '.package(url: "https://github.com/mxcl/Path.swift", from: "0.16.1")' Path

Installing packages:
	.package(url: "https://github.com/mxcl/Path.swift", from: "0.16.1")
		Path
With SwiftPM flags: []
Working in: /tmp/tmpmar21_vs
Fetching https://github.com/mxcl/Path.swift
Completed resolution in 2.58s
Cloning https://github.com/mxcl/Path.swift
Resolving https://github.com/mxcl/Path.swift at 0.16.2
Compile Swift Module 'Path' (9 sources)
Compile Swift Module 'jupyterInstalledPackages' (1 sources)
Linking ./.build/x86_64-unknown-linux/debug/libjupyterInstalledPackages.so
Initializing Swift...
Loading library...
Installation complete!


In [33]:
import TensorFlow

## Training example data structure

A training example data structure consists of training data and a label.

In [34]:
/// A single training example that pairs input `data` with its `label`.
///
/// Depending on the concrete `Data` and `Label` types, one `Example` may hold a whole
/// batch rather than a single sample.
public struct Example<Data: Differentiable, Label> {
    /// The input data of the example.
    public var data: Data
    /// The label associated with `data`.
    public var label: Label

    /// Creates an example from `data` and its `label`.
    public init(data: Data, label: Label) {
        self.label = label
        self.data = data
    }
}

## Trainer (learner)

A `Trainer` is responsible for initializing and training a model on a given dataset. It can be thought of as both the controller and the environment of model training.

### Core properties

`Trainer` contains three kinds of properties:
* Core units: `model`, `dataset`, `optimizer`, `lossFunction`
* Training states: `epochCount`, `currentEpoch`, `currentGradient`, `currentLoss`
* Delegates: `delegates`, objects that are notified at each stage of training

In [35]:
/// Control-flow signals a trainer delegate can throw to alter the training loop.
///
/// - `skipEpoch`: abandon the current epoch; caught by the trainer's per-epoch loop.
/// - `skipBatch`: abandon the current batch; caught by the trainer's per-batch loop.
/// - `stop`: end training early; caught around the whole of `train(epochCount:)`.
public enum TrainerAction: Error {
    case skipEpoch, skipBatch, stop
}

In [56]:
/// A model trainer, responsible for initializing and training a model on a given dataset.
///
/// Generic parameters:
/// - `D`: The dataset, a collection of `Example`s whose `data` is the model's input type.
/// - `Label`: The label type of each example.
/// - `L`: The floating-point loss type produced by the loss function.
/// - `O`: The optimizer; its `Model` is the type of model being trained.
// NOTE: When TF-421 is fixed, make `Label` not constrained to `Differentiable`.
public final class Trainer<D: Collection, Label: Differentiable,
                           L: Differentiable & BinaryFloatingPoint,
                           O: TensorFlow.Optimizer & AnyObject>
    where D.Element == Example<O.Model.Input, Label>,
          O.Scalar: Differentiable, L == L.CotangentVector
{
    // Common type aliases.
    public typealias Dataset = D
    public typealias Loss = L
    public typealias Optimizer = O
    public typealias Model = Optimizer.Model
    public typealias Data = Model.Input
    public typealias Variables = Model.AllDifferentiableVariables
    // NOTE: When TF-421 is fixed, replace with:
    //   public typealias LossFunction = @differentiable (Model.Output, @nondiff Label) -> Loss
    public typealias LossFunction = @differentiable (Model.Output, Label) -> Loss
    // NOTE(review): not referenced anywhere in this file; kept for source compatibility.
    public typealias EventHandler = (Trainer) throws -> Void
    
    /// The dataset on which the model will be trained.
    public let dataset: Dataset
    /// The optimizer used for updating model parameters along gradient vectors.
    public var optimizer: Optimizer
    /// The function that computes a loss value when given a prediction and a label.
    public var lossFunction: LossFunction
    /// The model being trained.
    public var model: Model
    
    /// The number of total epochs.
    public private(set) var epochCount: Int = .zero
    /// The current epoch.
    public private(set) var currentEpoch: Int = .zero
    /// The current gradient.
    public private(set) var currentGradient: Model.CotangentVector = .zero
    /// The current loss.
    public private(set) var currentLoss: Loss = .zero
    
    /// Base class for objects that are notified at each stage of training.
    ///
    /// Subclass it and override the hooks of interest; every hook does nothing by default.
    /// Hooks are `throws`: the trainer handles some `TrainerAction` values thrown from its
    /// training loop, and other errors propagate out of `train(epochCount:)`.
    open class Delegate {
        /// Creates a delegate.
        // Declared explicitly: the implicit initializer of a public/open class is only
        // `internal`, which would prevent other modules from subclassing `Delegate`.
        public init() {}

        /// Hook for the start of model training, before the first epoch.
        open func trainingWillStart(trainer: Trainer) throws {}
        /// Hook for the completion of model training.
        open func trainingDidFinish(trainer: Trainer) throws {}
        /// Hook for the start of an epoch.
        open func epochWillStart(trainer: Trainer) throws {}
        /// Hook for the completion of an epoch.
        open func epochDidFinish(trainer: Trainer) throws {}
        /// Hook for the start of model validation.
        // NOTE(review): nothing in this file invokes this hook yet.
        open func validationWillStart(trainer: Trainer) throws {}
        /// Hook for the start of training on a batch.
        open func batchWillStart(trainer: Trainer) throws {}
        /// Hook for the completion of training on a batch.
        open func batchDidFinish(trainer: Trainer) throws {}
        /// Hook for when a new loss has been computed and stored in `currentLoss`.
        open func trainerDidProduceNewLoss(trainer: Trainer) throws {}
        /// Hook for when a new gradient has been computed and stored in `currentGradient`.
        open func trainerDidProduceNewGradient(trainer: Trainer) throws {}
        /// Hook for when the optimizer has updated the model's parameters.
        open func optimizerDidUpdate(trainer: Trainer) throws {}
    }
    /// The delegates to notify, in array order, at each stage of training.
    public var delegates: [Delegate] = []
    
    /// The context used for layer applications.
    private let context = Context(learningPhase: .training)

    /// Creates a trainer.
    ///
    /// - Parameters:
    ///   - dataset: The dataset which will be trained on.
    ///   - lossFunction: The loss function.
    ///   - optimizer: The optimizer used for updating model parameters along
    ///     gradient vectors.
    ///   - modelInitializer: The closure that produces the model to be trained; it is
    ///     called exactly once, during initialization.
    ///
    public init(dataset: Dataset,
                lossFunction: @escaping LossFunction,
                optimizer: Optimizer,
                initializingWith modelInitializer: () -> Model) {
        self.dataset = dataset
        self.optimizer = optimizer
        self.lossFunction = lossFunction
        self.model = modelInitializer()
    }
}

### Methods

The core method on `Trainer` is `train(epochCount:)`.

In [58]:
extension Trainer {
    /// Trains the model on the given batch.
    ///
    /// Computes the loss and its gradient with respect to the model, then lets the optimizer
    /// update the model's parameters. Delegates are notified after the loss is stored
    /// (`trainerDidProduceNewLoss`), after the gradient is stored
    /// (`trainerDidProduceNewGradient`), and after the update (`optimizerDidUpdate`).
    ///
    /// - Parameter batch: The batch of input data and labels to be trained on.
    /// - Throws: Any error thrown by a delegate hook.
    ///
    private func train(on batch: Dataset.Element) throws {
        // NOTE: When the "subset of parameters" bug is fixed, replace with:
        //   let (loss, grad) = model.valueWithGradient { model -> Loss in
        //      let y = model.applied(to: batch.data, in: context)
        //      return lossFunction(y, batch.label)
        //   }
        // Workaround: differentiate with respect to both the model and the label, then
        // discard the label's gradient.
        let (loss, (grad, _)) = model.valueWithGradient(at: batch.label) {
            (model, label) -> Loss in
            let y = model.applied(to: batch.data, in: context)
            return lossFunction(y, label)
        }
        // NOTE: Put this inside `valueWithGradient`'s trailing closure when differentiation
        // supports throwing functions.
        currentLoss = loss
        try delegates.forEach { try $0.trainerDidProduceNewLoss(trainer: self) }
        currentGradient = grad
        try delegates.forEach { try $0.trainerDidProduceNewGradient(trainer: self) }
        optimizer.update(&model.allDifferentiableVariables, along: grad)
        // `batchDidFinish` is sent once per batch by `train(atEpoch:)`; here only the
        // optimizer-update hook belongs.
        try delegates.forEach { try $0.optimizerDidUpdate(trainer: self) }
    }
    
    /// Performs the `index`-th training epoch over the whole dataset.
    ///
    /// A `TrainerAction.skipBatch` thrown while processing a batch (including from the
    /// `batchWillStart`/`batchDidFinish` hooks) skips only that batch; other errors propagate.
    ///
    /// - Parameter index: The epoch index.
    /// - Throws: Any error other than `TrainerAction.skipBatch` thrown by a delegate hook.
    private func train(atEpoch index: Int) throws {
        currentEpoch = index
        try delegates.forEach { try $0.epochWillStart(trainer: self) }
        for batch in dataset {
            do {
                try delegates.forEach { try $0.batchWillStart(trainer: self) }
                try train(on: batch)
                try delegates.forEach { try $0.batchDidFinish(trainer: self) }
            } catch TrainerAction.skipBatch {
                // Skip only this batch; `break` would abandon the rest of the epoch.
                continue
            }
        }
        try delegates.forEach { try $0.epochDidFinish(trainer: self) }
    }

    /// Starts training.
    ///
    /// A `TrainerAction.skipEpoch` skips the rest of the current epoch and continues with
    /// the next one. A `TrainerAction.stop` ends training immediately, without sending
    /// `trainingDidFinish`.
    ///
    /// - Parameter epochCount: The number of epochs that will be run. Must be non-negative.
    /// - Throws: Any error other than a handled `TrainerAction` thrown by a delegate hook.
    ///
    public func train(epochCount: Int) throws {
        precondition(epochCount >= 0, "epochCount must be non-negative")
        self.epochCount = epochCount
        self.currentEpoch = 0
        do {
            try delegates.forEach { try $0.trainingWillStart(trainer: self) }
            for epoch in 0..<epochCount {
                do {
                    try train(atEpoch: epoch)
                } catch TrainerAction.skipEpoch {
                    // Skip only this epoch; `break` would end training altogether.
                    continue
                }
            }
            try delegates.forEach { try $0.trainingDidFinish(trainer: self) }
        } catch TrainerAction.stop {
            return
        }
    }
}

In the short term, `Trainer` is also available under the name `Learner`.

In [59]:
/// An alternative name for `Trainer`, provided in the short term.
public typealias Learner = Trainer

## Handlers

In [70]:
extension Trainer {
    /// A delegate that records the current loss and the optimizer's learning rate each time
    /// it receives `optimizerDidUpdate`.
    public class Recorder: Delegate {
        /// The recorded losses, in the order they were observed.
        public var losses: [Loss] = []
        /// The recorded learning rates, in the order they were observed.
        public var learningRates: [Optimizer.Scalar] = []

        /// Creates an empty recorder.
        // Declared explicitly: the implicit initializer of a public class is only `internal`,
        // which would prevent other modules from constructing a `Recorder`.
        public override init() {
            super.init()
        }

        /// Clears previously recorded values so each training run starts fresh.
        public override func trainingWillStart(trainer: Trainer) throws {
            losses.removeAll()
            learningRates.removeAll()
        }

        /// Appends the trainer's current loss and the optimizer's current learning rate.
        public override func optimizerDidUpdate(trainer: Trainer) throws {
            losses.append(trainer.currentLoss)
            learningRates.append(trainer.optimizer.learningRate)
        }
    }
    
    /// A delegate that, before each batch, assigns `schedule(currentEpoch / epochCount)` to
    /// the model property at `keyPath`.
    public class ParameterScheduler<Parameter>: Delegate {
        /// The model property that is updated before each batch.
        public var keyPath: WritableKeyPath<Model, Parameter>
        /// Maps the ratio `currentEpoch / epochCount` to the value assigned to `keyPath`.
        public var schedule: (Optimizer.Scalar) -> Parameter

        /// Creates a scheduler for the model property at `keyPath`.
        public init(keyPath: WritableKeyPath<Model, Parameter>,
                    schedule: @escaping (Optimizer.Scalar) -> Parameter) {
            self.keyPath = keyPath
            self.schedule = schedule
        }
        
        /// Assigns the scheduled value for the current training progress to the model.
        public override func batchWillStart(trainer: Trainer) throws {
            // The ratio depends only on the epoch, so the value is constant within an epoch.
            let ratio = Optimizer.Scalar(trainer.currentEpoch)
                / Optimizer.Scalar(trainer.epochCount)
            trainer.model[keyPath: keyPath] = schedule(ratio)
        }
    }
}

## Examples

### Simple training loop

In [None]:
// NOTE(review): `outputCount` is not used by `MyModel` below (`layer2` has 2 outputs);
// confirm whether the example meant `outputSize: outputCount`.
let outputCount = 10

/// A small example model: a 2→4 dense ReLU layer followed by a 4→2 dense ReLU layer.
struct MyModel: Layer {
    var layer1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu)
    var layer2 = Dense<Float>(inputSize: 4, outputSize: 2, activation: relu)

    /// Applies `layer1` and then `layer2` to `input` in the given `context`.
    @differentiable
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        let hidden = layer1.applied(to: input, in: context)
        return layer2.applied(to: hidden, in: context)
    }
}

// let trainer = Trainer()