# Training loop callbacks

In [None]:
// TODO: Integrate callbacks into the `train` function.

# Training loops

In [None]:
import TensorFlow

In [None]:
// Softmax cross entropy loss function.
// TODO: This should be moved into the TensorFlow library/APIs.

/// Returns the softmax cross-entropy loss between the logits `features` and `labels`.
///
/// The per-example losses produced by `Raw.softmaxCrossEntropyWithLogits` are reduced to a single
/// value with `mean()`. Differentiation uses the custom VJP `_vjpSoftmaxCrossEntropy`.
@differentiable(vjp: _vjpSoftmaxCrossEntropy)
func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
    features: Tensor<Scalar>, labels: Tensor<Scalar>
) -> Tensor<Scalar> {
    let perExampleLoss = Raw.softmaxCrossEntropyWithLogits(features: features, labels: labels).loss
    return perExampleLoss.mean()
}

/// Custom vector-Jacobian product for `softmaxCrossEntropy(features:labels:)`.
///
/// Returns the mean loss and a pullback that maps an incoming cotangent `v` to the cotangents
/// with respect to `features` and `labels`.
@usableFromInline
func _vjpSoftmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
    features: Tensor<Scalar>, labels: Tensor<Scalar>
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
    // `grad` is the op's `backprop` output: per the TensorFlow op documentation, the gradient of
    // the per-example loss with respect to `features`.
    let (loss, grad) = Raw.softmaxCrossEntropyWithLogits(features: features, labels: labels)
    // The forward pass averages the per-example losses, so the pullback divides by the leading
    // (batch) dimension of `features`.
    let batchSize = Tensor<Scalar>(features.shapeTensor[0])
    return (loss.mean(), { v in
        let scale = v / batchSize
        // d(loss)/d(labels) = -logSoftmax(features). This matches TensorFlow's registered
        // gradient for this op. It was previously a constant zero. It is computed inside the
        // pullback so the forward pass does no extra work.
        return (scale * grad, scale * -Raw.logSoftmax(logits: features))
    })
}

In [None]:
// Example type for use with `Dataset`.
// TODO: The usage of this should be re-evaluated.

/// Pairs a data tensor with a labels tensor; used as the element type of `Dataset`.
public struct Example<DataScalar: TensorFlowFloatingPoint, LabelScalar: TensorFlowFloatingPoint>: TensorGroup {
    /// The input data.
    public var data: Tensor<DataScalar>
    /// The labels associated with `data`.
    public var labels: Tensor<LabelScalar>
}

In [None]:
/// A training loop.
///
/// Makes one pass over `dataset` in a training-phase `Context`. For each batch it:
/// 1. computes `loss` of the model's predictions against the batch labels,
/// 2. prints that loss value, and
/// 3. asks `optimizer` to update the model's variables at `variablesKeyPath` using the gradient.
///
/// - Parameters:
///   - model: The model to train; updated in place.
///   - variablesKeyPath: Key path from the model to its differentiable variables, which are
///     updated by the optimizer.
///   - dataset: The batches to iterate; each element supplies `data` (inputs) and `labels`.
///   - optimizer: The optimizer that applies the gradient updates.
///   - loss: A differentiable function taking `(predictions, labels)` and returning the loss.
public func train<M, O: Optimizer, S>(
    _ model: inout M,
    at variablesKeyPath: WritableKeyPath<M, M.AllDifferentiableVariables>,
    on dataset: Dataset<Example<S, S>>,
    using optimizer: inout O,
    loss: @escaping @differentiable (Tensor<S>, Tensor<S>) -> Tensor<S>
) where O.Model == M, O.Scalar == S,
        M.Input == Tensor<S>, M.Output == Tensor<S>
{
    let context = Context(learningPhase: .training)
    for batch in dataset {
        let (x, y) = (batch.data, batch.labels)
        // The result is named `batchLoss` so it does not shadow the `loss` function parameter
        // called inside the closure. The gradient with respect to `y` is computed but discarded.
        let (batchLoss, (𝛁model, _)) = model.valueWithGradient(at: y) { (model, y) -> Tensor<S> in
            let preds = model.applied(to: x, in: context)
            return loss(preds, y)
        }
        print(batchLoss)
        optimizer.update(&model[keyPath: variablesKeyPath], along: 𝛁model)
    }
}

In [None]:
// Example usage.
// Trains a single dense layer (784 inputs -> 10 outputs) on random data.
var model = Dense<Float>(inputSize: 784, outputSize: 10)
var optimizer = SGD<Dense<Float>, Float>(learningRate: 0.1)

// Leading dimension: 10 batches. Assuming `Dataset(elements:)` slices along it, each batch has
// data of shape [10, 784] and labels of shape [10, 10].
let data = Tensor<Float>(randomNormal: [10, 10, 784])
// `softmaxCrossEntropy` expects labels with the same [batch, classes] shape as the logits, each
// row a probability distribution, so random values are normalized with softmax over the last
// axis. (Previously labels were [10, 10], giving each batch a length-10 vector that does not
// match the [10, 10] logits.)
let labels = Raw.softmax(logits: Tensor<Float>(randomNormal: [10, 10, 10]))
let dataset = Dataset<Example<Float, Float>>(elements: Example<Float, Float>(data: data, labels: labels))

train(&model, at: \Dense<Float>.allDifferentiableVariables, on: dataset, using: &optimizer, loss: softmaxCrossEntropy)