# Train Transformer for the Lang2motion task

In [None]:
// For Colab: build the language2motion Swift packages in release mode and
// install them under $cwd/swift-install so the imports in the next cell resolve.
%install-location $cwd/swift-install
%install-swiftpm-flags -c release
%install '.package(url: "https://github.com/wojtekcz/language2motion.git", .branch("master"))' Datasets TranslationModels TextModels ModelSupport SummaryWriter LangMotionModels TrainingLoop

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import FoundationXML
import ModelSupport
import Datasets
import SummaryWriter
import LangMotionModels
import TrainingLoop
import PythonKit
import x10_optimizers_optimizer

In [None]:
import PythonKit

// Pull in the notebook display helper and ask it to use matplotlib's "inline" mode.
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")

## What's the GPU?

In [None]:
import Foundation

/// Runs `command` through `/bin/bash -c` and returns what it wrote to standard output.
///
/// Standard error is not captured. Bytes that are not valid UTF-8 are replaced with
/// U+FFFD instead of crashing the notebook kernel.
///
/// - Parameter command: Shell source passed verbatim to `bash -c`.
/// - Returns: The command's complete standard output.
func shell(_ command: String) -> String {
    let task = Process()
    let pipe = Pipe()

    task.standardOutput = pipe
    task.arguments = ["-c", command]
    task.launchPath = "/bin/bash"
    task.launch()

    // Drain the pipe before waiting so a chatty command cannot block on a full pipe buffer.
    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    // Reap the child so repeated calls do not leave zombie processes behind.
    task.waitUntilExit()
    // Lossy decoding: tool output is not guaranteed to be valid UTF-8.
    return String(decoding: data, as: UTF8.self)
}

/// Runs `command` via `shell(_:)` and prints the captured standard output.
func sh(_ command: String) {
    let output = shell(command)
    print(output)
}

// Report the attached GPU: extend PATH and LD_LIBRARY_PATH, then run nvidia-smi.
sh("""
export PATH="$PATH:/opt/bin:/swift/toolchain/usr/bin"
export LD_LIBRARY_PATH="/usr/lib64-nvidia:$LD_LIBRARY_PATH"
nvidia-smi
""")

In [None]:
// sh("ps ax")

In [None]:
// sh("kill 2150")

In [None]:
// List the contents of run_17's checkpoint directory.
sh("ls -l /content/data/runs/Lang2motion/run_17/checkpoints")

In [None]:
// Move everything out of run_17's checkpoint directory into /content/data/runs.
// NOTE(review): looks like a one-off cleanup step; confirm it is still needed before re-running.
sh("mv /content/data/runs/Lang2motion/run_17/checkpoints/* /content/data/runs")

## Download data

In [None]:
// Dataset variant to use; its raw value is spliced into the dataset file names.
let datasetSize: DatasetSize = .full
// Base file name shared by the archive and plist. Later cells append "tgz"/"plist"
// with no separator, so this presumably relies on the raw value being empty or
// ending in "." — TODO confirm against DatasetSize.
let dataset_name = "motion_dataset_v3.10Hz.\(datasetSize.rawValue)"

In [None]:
// sh("mkdir -p /content/data/motion_images/")
// Download the dataset archive and the vocabulary into /content/data/ and unpack the archive.
// The directory is created first: on a fresh runtime it does not exist yet, `cd` would
// fail, and the files would be downloaded into the wrong working directory.
// `-N` and `--skip-old-files` leave files that are already present untouched.
sh("""
mkdir -p /content/data/
cd /content/data/
wget -nv --show-progress -N https://github.com/wojtekcz/language2motion/releases/download/v0.3.0/\(dataset_name)tgz
wget -nv -N https://github.com/wojtekcz/language2motion/releases/download/v0.1.0/vocab.txt
tar xzvf \(dataset_name)tgz --skip-old-files
""")

In [None]:
//https://github.com/wojtekcz/language2motion/releases/download/v0.3.0/run_17.model.e35.tgz
// Download the run_17 epoch-35 checkpoint and unpack it into run_18's checkpoint directory.
// `tar -C` requires the target directory to exist, and it is only created by a later
// cell, so create it here. The hard-coded "run_18" must match `runName` set further down.
sh("""
mkdir -p /content/data/runs/Lang2motion/run_18/checkpoints
cd /content/data/
wget -nv --show-progress -N https://github.com/wojtekcz/language2motion/releases/download/v0.3.0/run_17.model.e35.tgz
tar xzvf run_17.model.e35.tgz --skip-old-files -C runs/Lang2motion/run_18/checkpoints
""")

## Set training params

In [None]:
// Name of this run; used for the log/checkpoint directory.
let runName = "run_18"
// let batchSize = 4
let batchSize = 50
let maxTextSequenceLength =  20
let maxMotionLength =  100
let nEpochs = 10
// let peakLearningRate: Float = 5e-4
let peakLearningRate: Float = 2e-5

let stepsPerEpoch = 127 // function of training set size and batching configuration

// Adam moment decay rates. Both must stay strictly below 1: in the standard Adam
// update a rate of exactly 1 means the moment estimates never absorb a new gradient,
// and the bias-correction factor used by `learningRateUpdater` divides by
// 1 - beta1^step, which would be zero.
let beta1: Float = 0.9
let beta2: Float = 0.999
let useBiasCorrection = false

// let datasetSize: DatasetSize = .midi

// Echo the configuration so it is recorded in the notebook output.
print("runName: \(runName)")
print("batchSize: \(batchSize)")
print("maxTextSequenceLength: \(maxTextSequenceLength)")
print("maxMotionLength: \(maxMotionLength)")
print("nEpochs: \(nEpochs)")
print("peakLearningRate: \(peakLearningRate)")
print("datasetSize: \(datasetSize)")

In [None]:
// Root directory for downloaded data and run artifacts.
let dataURL = URL(fileURLWithPath: "/content/data/")
// Plist file of the selected dataset variant (same base name as the downloaded archive).
let motionDatasetURL = dataURL.appendingPathComponent("motion_dataset_v3.10Hz.\(datasetSize.rawValue)plist")

// Per-run log directory and its "checkpoints" subdirectory.
let logdirURL = dataURL.appendingPathComponent("runs/Lang2motion/\(runName)", isDirectory: true)
let checkpointURL = logdirURL.appendingPathComponent("checkpoints", isDirectory: true)
// Create the whole directory chain up front; a failure here aborts the cell.
try! FileManager().createDirectory(at: checkpointURL, withIntermediateDirectories: true)

## Select eager or X10 backend

In [None]:
// Backend selection: swap the commented line in to run on XLA instead of TF eager.
// let device = Device.defaultXLA
let device = Device.defaultTFEager
print(device)

## X10 warm-up

In [None]:
// let eagerTensor1 = Tensor([0.0, 1.0, 2.0])
// let eagerTensor2 = Tensor([1.5, 2.5, 3.5])
// let eagerTensorSum = eagerTensor1 + eagerTensor2
// print(eagerTensorSum)
// print(eagerTensor1.device)
// let x10Tensor2 = Tensor([1.5, 2.5, 3.5], on: Device.defaultXLA)
// print(x10Tensor2.device)

## Instantiate model

In [None]:
/// Instantiate the text processor: vocabulary from vocab.txt plus a
/// case-insensitive BERT tokenizer with "[UNK]" as the unknown token.
print("instantiate text processor")
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
let vocabulary: Vocabulary = try! Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(vocabulary: vocabulary, caseSensitive: false, unknownToken: "[UNK]", maxTokenLength: nil)
let textProcessor = TextProcessor(vocabulary: vocabulary, tokenizer: tokenizer)

/// Model configuration. The same `nbJoints`/`nbMixtures` values are handed to the
/// loss arguments further down in the notebook.
print("instantiate model")
let config = LangMotionTransformerConfig(
    vocabSize: vocabulary.count,
    nbJoints: 47, // TODO: get value from dataset
    nbMixtures: 20,
    layerCount: 6,
    modelSize: 256,
    feedForwardSize: 1024,
    headCount: 8,
    dropoutProbability:  0.1,
    sentenceMaxPositionalLength: 100,
    motionMaxPositionalLength: 500
)

/// Alternative: create a new model from the config instead of loading a checkpoint.
// var model = LangMotionTransformer(config: config)

/// Load the model from the checkpoint named "model.e<start_epoch>" in `checkpointURL`.
/// `start_epoch` is also the first epoch index used by the training loop cell.
print("checkpointURL: \(checkpointURL.path)")
let start_epoch = 35
var model = try! LangMotionTransformer(checkpoint: checkpointURL, config: config, name: "model.e\(start_epoch)")

## Load dataset

In [None]:
print("\nLoading dataset...")

// Build the dataset. The trailing closure converts one MotionSample into a
// single-sample LangMotionBatch: the first annotation becomes the source sentence
// and the motion is split into the model input part and the training target.
// The dataset-level length limit uses the same `maxMotionLength` constant as the
// per-sample preprocessing so the two cannot drift apart.
var dataset = try Lang2Motion(
    motionDatasetURL: motionDatasetURL,
    batchSize: batchSize,
    minMotionLength: 10,
    maxMotionLength: maxMotionLength,
    trainTestSplit: 1.0,
    demultiplyMotions: true,
    device: device
) { (motionSample: MotionSample) -> LangMotionBatch in
    let sentence = textProcessor.preprocess(
        sentence: motionSample.annotations[0],
        maxTextSequenceLength: maxTextSequenceLength
    )
    let (target, motionPart) = LangMotionBatch.preprocessTargetMotion(
        sampleID: motionSample.sampleID,
        motion: motionSample.motion,
        maxMotionLength: maxMotionLength
    )
    let source = LangMotionBatch.Source(sentence: sentence, motionPart: motionPart)
    return LangMotionBatch(data: source, label: target)
}

print("Dataset acquired.")

## Optimizer

In [None]:
// Optimizer over all model parameters, built from the weight-decayed Adam factory
// with the peak learning rate and the beta1/beta2 values from the params cell.
var optimizer = x10_optimizers_optimizer.GeneralOptimizer(
    for: model,
    TensorVisitorPlan(model.differentiableVectorView),
    defaultOptimizer: makeWeightDecayedAdam(
      learningRate: peakLearningRate,
      beta1: beta1,
      beta2: beta2
    )
)

// Learning-rate schedule queried per step by `learningRateUpdater`: a fixed peak rate
// wrapped in a 20-step linear warm-up, wrapped in a linear decay that starts at step 10.
// The slope is peakLearningRate / (stepsPerEpoch * nEpochs), negated.
var scheduledLearningRate = LinearlyDecayedParameter(
  baseParameter: LinearlyWarmedUpParameter(
      baseParameter: FixedParameter<Float>(peakLearningRate),
      warmUpStepCount: 20,
      warmUpOffset: 0),
  slope: -(peakLearningRate / Float(stepsPerEpoch * nEpochs)),  // The LR decays linearly to zero.
  startStep: 10
)

/// Training-loop callback that sets the optimizer's learning rate from the schedule
/// right before each parameter update.
///
/// Reacts only to `.updateStart`. When `useBiasCorrection` is on, the scheduled rate
/// is additionally scaled by `sqrt(1 - beta2^t) / (1 - beta1^t)`.
public func learningRateUpdater<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    guard event == .updateStart else { return }

    let adam = loop.optimizer as! GeneralOptimizer<LangMotionTransformer>
    // Scheduled rates and bias correction count steps from 1.
    let upcomingStep = adam.step + 1

    var rate = scheduledLearningRate(forStep: UInt64(upcomingStep))
    if useBiasCorrection {
        let t = Float(upcomingStep)
        rate *= sqrtf(1 - powf(beta2, t)) / (1 - powf(beta1, t))
    }
    adam.learningRate = rate
    // print("\noptimizer: step: \(optimizer.step), learningRate: \(optimizer.learningRate)")
}

## Training helpers

In [None]:
// Loss function arguments: joint and mixture counts come from the model config so
// the loss matches the model; the mixture regularizer is switched off ("None", weight 0).
let args = LossArgs(
        nb_joints: config.nbJoints,
        nb_mixtures: config.nbMixtures,
        mixture_regularizer_type: "None",  // ["cv", "l2", "None"]
        mixture_regularizer: 0.0,
        device: device
)

/// Loss passed to the training loop: the normal-mixture surrogate loss summed over
/// all elements and divided by the product of the loss tensor's first two dimensions.
///
/// - Parameters:
///   - y_pred: Mixture-model predictions produced by the model.
///   - y_target: Batch target providing the ground-truth motion and stop signals.
/// - Returns: The averaged loss as a scalar tensor.
@differentiable
func embeddedNormalMixtureSurrogateLoss(y_pred: MixtureModelPreds, y_target: LangMotionBatch.Target) -> Tensor<Float> {
    // TODO: create tensor on device
    let hostTruth = TargetTruth(motion: y_target.targetTruth, stops: y_target.targetTruthStop)
    let deviceTruth = TargetTruth(copying: hostTruth, to: device)
    let elementLoss = normalMixtureSurrogateLoss(y_true: deviceTruth, y_pred: y_pred, args: args)
    let itemCount: Float = Float(elementLoss.shape[0] * elementLoss.shape[1])
    return elementLoss.sum() / itemCount
}

/// Training-loop callback that writes a checkpoint named `model.e<epoch>` into
/// `checkpointURL` at the end of every epoch.
///
/// Does nothing for other events or when the loop reports no epoch index.
///
/// - Throws: Whatever `writeCheckpoint` throws. The error is propagated to the
///   training loop instead of crashing the process.
public func saveCheckpoint<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    guard event == .epochEnd, let epochIndex = loop.epochIndex else {
        return
    }
    // This callback is only meaningful for loops that train a LangMotionTransformer.
    let transformer: LangMotionTransformer = loop.model as! LangMotionTransformer
    try transformer.writeCheckpoint(to: checkpointURL, name: "model.e\(epochIndex+1)")
}

/// Records training-loss statistics from training-loop events and writes them as
/// scalar summaries under `logdirURL`.
///
/// Per batch it writes "TrainingLoss"; per epoch it writes "EpochTrainingLoss",
/// the mean of the batch losses seen since the epoch started.
public class StatsRecorder {
    let summaryWriter = SummaryWriter(logdir: logdirURL, flushMillis: 30*1000)
    /// Number of batches recorded since this recorder was created (summary step).
    public var trainingStepCount = 0
    /// Number of batches recorded in the current epoch.
    public var trainingBatchCount = 0
    /// Sum of the batch losses recorded in the current epoch.
    public var trainingLossSum: Float = 0.0
    /// Zero-based epoch index, set by the caller before each `fit` call.
    public var epochIndex = 0 // FIXME: Workaround

    /// Training-loop callback; updates the counters and writes summaries.
    public func writeStats<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
        if event == .batchEnd {
            guard
            // let batchIndex = loop.batchIndex,
            let trainingLoss = loop.lastLoss,
            // A missing scalar value is skipped rather than force-unwrapped.
            let lossValue = trainingLoss.scalar else {
                return
            }
            // print("\nbatch stats: batchIndex: \(batchIndex), trainingStepCount: \(trainingStepCount), trainingLoss: \(trainingLoss)")
            summaryWriter.writeScalarSummary(tag: "TrainingLoss", step: trainingStepCount, value: lossValue)
            trainingStepCount += 1
            trainingBatchCount += 1
            trainingLossSum += Float(lossValue)
        }
        if event == .epochStart {
            trainingBatchCount = 0
            trainingLossSum = 0.0
        }
        if event == .epochEnd {
            // guard let epochIndex = loop.epochIndex else {
            //     return
            // }
            // With no recorded batches the mean is 0/0; skip the summary instead of writing NaN.
            guard trainingBatchCount > 0 else {
                return
            }
            let current_epoch = epochIndex + 1
            let epochTrainingLoss = trainingLossSum / Float(trainingBatchCount)
            // print("\nepoch stats: current_epoch: \(current_epoch), epochTrainingLoss: \(epochTrainingLoss)")
            summaryWriter.writeScalarSummary(tag: "EpochTrainingLoss", step: current_epoch, value: epochTrainingLoss)
        }
        if event == .fitEnd {
            summaryWriter.flush()
        }
    }
}

## Training loop

In [None]:
// Recorder whose callback writes loss summaries during training.
let statsRecorder = StatsRecorder()

// Training loop
print("\nSetting up the training loop")
let trainingProgress = TrainingProgress(metrics: [.loss])
// Callbacks: progress display, summary writing, and per-update learning-rate scheduling.
var trainingLoop = TrainingLoop(
  training: dataset.trainEpochs,
  validation: dataset.testBatches,
  optimizer: optimizer,
  lossFunction: embeddedNormalMixtureSurrogateLoss,
  callbacks: [trainingProgress.update, statsRecorder.writeStats, learningRateUpdater])

print("\nTraining Transformer for the Lang2motion task!")
// FIXME: epoch loop workaround for checkpoint saving
// Each iteration fits exactly one epoch, then writes "model.e<epoch>". Epoch numbering
// continues from `start_epoch`, the epoch of the checkpoint the model was loaded from.
for epochIndex in start_epoch..<start_epoch+nEpochs {
    print("epoch \(epochIndex+1)/\(start_epoch + nEpochs)")
    // The recorder cannot see the global epoch number, so hand it over explicitly.
    statsRecorder.epochIndex = epochIndex
    try! trainingLoop.fit(&model, epochs: 1, on: device)
    try! model.writeCheckpoint(to: checkpointURL, name: "model.e\(epochIndex+1)")
}

// Also keep the final weights under a fixed name.
try! model.writeCheckpoint(to: checkpointURL, name: "model.final")
print("\nFinished training.")