# Train Transformer for the Lang2motion task

In [None]:
// for colab
%install-location $cwd/swift-install
%install-swiftpm-flags -c release
%install '.package(url: "https://github.com/wojtekcz/language2motion.git", .branch("master"))' Datasets TranslationModels TextModels ModelSupport SummaryWriter LangMotionModels TrainingLoop

## What's the GPU?

In [None]:
import Foundation

/// Runs `command` through `/bin/bash -c` and returns what it wrote to standard output.
///
/// Only standard output is attached to the pipe; standard error is left at the
/// `Process` default and is therefore not part of the returned string.
///
/// - Parameter command: Shell source handed verbatim to bash. Never pass untrusted input.
/// - Returns: The captured standard output. Byte sequences that are not valid UTF-8 are
///   replaced with U+FFFD rather than crashing the caller.
func shell(_ command: String) -> String {
    let task = Process()
    let pipe = Pipe()

    task.standardOutput = pipe
    task.arguments = ["-c", command]
    task.launchPath = "/bin/bash"
    task.launch()

    // Blocks until the write end of the pipe is closed.
    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    // Lossy decoding: tool output is not guaranteed to be UTF-8, and a failed
    // `String(data:encoding:)` force-unwrap would take the whole notebook kernel down.
    return String(decoding: data, as: UTF8.self)
}

/// Executes `command` with `shell(_:)` and prints the captured standard output.
///
/// - Parameter command: Shell source to run.
func sh(_ command: String) {
    let output = shell(command)
    print(output)
}

// Print the nvidia-smi report for this runtime. PATH and LD_LIBRARY_PATH are extended first
// so the binary and the NVIDIA libraries under /usr/lib64-nvidia can be found.
sh("""
export PATH="$PATH:/opt/bin:/swift/toolchain/usr/bin"
export LD_LIBRARY_PATH="/usr/lib64-nvidia:$LD_LIBRARY_PATH"
nvidia-smi
""")

## Run Colab SSH server

In [None]:
// run colab ssh server
// after it finishes, interrupt cell execution
// NOTE(review): this downloads a script from the master branch and pipes it straight into
// bash. Consider pinning the URL to a reviewed commit before running it.
sh("bash <(curl -s https://raw.githubusercontent.com/wojtekcz/language2motion/master/notebooks/Colab/swift_colab_ssh_server.sh)")

In [None]:
// List processes whose `ps ax` line contains "ssh" (the grep process itself may show up too).
sh("ps ax|grep ssh")

In [None]:
// Force-kill (SIGKILL) the ssh tunnel. 2211 is a hard-coded placeholder pid and must be
// replaced with the actual pid before running this cell.
sh("kill -9 2211")  // enter ssh pid to kill the tunnel

## Imports

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import FoundationXML
import ModelSupport
import Datasets
import SummaryWriter
import LangMotionModels
import TrainingLoop
import PythonKit
import x10_optimizers_optimizer

In [None]:
import PythonKit

%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")

In [None]:
// sh("ps ax")

In [None]:
// sh("kill 2150")

In [None]:
// sh("ls -la /content/data/runs/Lang2motion/run_19/checkpoints")
// sh("ls -la /content/checkpoints/")

In [None]:
// sh("mv /content/data/runs/Lang2motion/run_17/checkpoints/* /content/data/runs")

## Download data

In [None]:
// Dataset variant to train on; its rawValue is spliced into the dataset file names below.
let datasetSize: DatasetSize = .multi_midi
// Base name of the dataset files. File extensions are appended to it without a separator,
// so presumably rawValue already ends with "." — verify against the DatasetSize definition.
let dataset_name = "motion_dataset_v3.10Hz.\(datasetSize.rawValue)"
// Identifier of this training run; used as the run directory name under runs/Lang2motion/.
let runName = "run_55"

In [None]:
// Download the dataset archive and vocab.txt into /content/data/, then unpack the archive.
// `wget -N` uses timestamping, so an up-to-date local copy is not fetched again, and
// `tar --skip-old-files` leaves files that already exist on disk untouched.
sh("""
mkdir -p /content/data/
cd /content/data/
wget -nv --show-progress -N https://github.com/wojtekcz/language2motion/releases/download/v0.3.0/\(dataset_name)tgz
wget -nv -N https://github.com/wojtekcz/language2motion/releases/download/v0.1.0/vocab.txt
tar xzvf \(dataset_name)tgz --skip-old-files
""")

In [None]:
// sh("""
// cd /content/data/
// mv motion_dataset_v3.10Hz.plist motion_dataset_v3.10Hz.multi.plist
// """)

In [None]:
// sh("""
// cd /content/data/
// mkdir -p runs/Lang2motion/\(runName)/checkpoints
// wget -nv --show-progress -N https://github.com/wojtekcz/language2motion/releases/download/v0.3.0/run_16.model.e35.tgz
// tar xzvf run_16.model.e35.tgz --skip-old-files --no-same-owner -C runs/Lang2motion/\(runName)/checkpoints
// """)

In [None]:
// sh("ln -s /content/data/runs/Lang2motion/\(runName)/checkpoints /content/checkpoints")

## Set training params

In [None]:
// Batching and sequence-length limits used by the dataset and preprocessing cells.
let batchSize = 50
let maxTextSequenceLength =  20
let maxMotionLength =  50
let nEpochs = 150
// Peak learning rate of the schedule; alternatives tried earlier are kept for reference.
// let peakLearningRate: Float = 1e-3 // bs=50
// let peakLearningRate: Float = 5e-4
let peakLearningRate: Float = 2e-4 // bs=50
// let peakLearningRate: Float = 2e-5
// let peakLearningRate: Float = 1e-5
// let peakLearningRate: Float = 5e-6
// let peakLearningRate: Float = 5e-5
// let peakLearningRate: Float = 1e-4

// Function of training set size and batching configuration.
// The expression uses integer division, so 53/batchSize truncates and becomes 0 as soon as
// batchSize exceeds 53. Clamp to at least 1 so the learning-rate decay slope, which divides
// by stepsPerEpoch * nEpochs, can never divide by zero.
// NOTE(review): 53 is a hard-coded sample count — confirm it matches the selected dataset.
let stepsPerEpoch = max(1, 53/batchSize*2)

// Adam hyper-parameters.
let beta1: Float = 0.9
let beta2: Float = 0.999
let useBiasCorrection = false
let weightDecayRate: Float = 0.00

// Echo the effective configuration so it is recorded in the notebook output.
print("runName: \(runName)")
print("batchSize: \(batchSize)")
print("maxTextSequenceLength: \(maxTextSequenceLength)")
print("maxMotionLength: \(maxMotionLength)")
print("nEpochs: \(nEpochs)")
print("peakLearningRate: \(peakLearningRate)")
print("datasetSize: \(datasetSize)")
print("stepsPerEpoch: \(stepsPerEpoch)")

In [None]:
// Root directory holding the dataset, the vocabulary and all run outputs.
let dataURL = URL(fileURLWithPath: "/content/data/")
// Dataset plist; the name is built from datasetSize.rawValue followed directly by "plist".
let motionDatasetURL = dataURL.appendingPathComponent("motion_dataset_v3.10Hz.\(datasetSize.rawValue)plist")

// Directory layout: <dataURL>/runs/Lang2motion/<runName>/checkpoints
let logdirURL = dataURL.appendingPathComponent("runs/Lang2motion/", isDirectory: true)
let rundirURL = logdirURL.appendingPathComponent(runName, isDirectory: true)
let checkpointURL = rundirURL.appendingPathComponent("checkpoints", isDirectory: true)
// Create the checkpoint directory together with any missing parents.
// try! aborts the cell if the directory cannot be created.
try! FileManager().createDirectory(at: checkpointURL, withIntermediateDirectories: true)

## Select eager or X10 backend

In [None]:
// Backend selection: swap the commented line to run on X10/XLA instead of TF eager.
// let device = Device.defaultXLA
let device = Device.defaultTFEager
print(device)

## Instantiate model

In [None]:
/// instantiate text processor
print("instantiate text processor")
// The vocabulary is read from vocab.txt in the data directory; try! aborts if loading fails.
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
let vocabulary: Vocabulary = try! Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(vocabulary: vocabulary, caseSensitive: false, unknownToken: "[UNK]", maxTokenLength: nil)
let textProcessor = TextProcessor(vocabulary: vocabulary, tokenizer: tokenizer)

/// instantiate model
print("instantiate model")
// modelSize is also used below to derive the attention temperatures.
let modelSize = 128
let config = LangMotionTransformerConfig(
    vocabSize: vocabulary.count,
    nbJoints: 47, // TODO: get value from dataset
    nbMixtures: 20,
    layerCount: 6,
    modelSize: modelSize,
    feedForwardSize: 512,
    headCount: 4,
    dropoutProbability:  0.1,
    sentenceMaxPositionalLength: 100,
    motionMaxPositionalLength: 500,
    motionPositionalEncodingSize: 32,
    encoderSelfAttentionTemp: sqrt(Double(modelSize)),
    decoderSourceAttentionTemp: sqrt(Double(modelSize)),
    // NOTE(review): unlike the two temperatures above this one is not square-rooted —
    // confirm that is intentional.
    decoderSelfAttentionTemp: Double(modelSize)
)

// Zero-based index of the first epoch to train; overwritten below when resuming.
var start_epoch = 0

/// create new model
// var model = LangMotionTransformer(config: config)

/// load model checkpoint
// Resume from checkpoint "model.e<start_epoch>".
// NOTE(review): the checkpoint is read from the hard-coded run_54 directory, not from the
// current runName — confirm this is the intended source run.
start_epoch = 7
var model = try! LangMotionTransformer(checkpoint: logdirURL.appendingPathComponent("run_54/checkpoints"), config: config, name: "model.e\(start_epoch)")

## Load dataset

In [None]:
// motionDatasetURL.path

In [None]:
print("\nLoading dataset...")

// Build the Lang2Motion dataset from the motion plist.
// The upper motion-length bound uses the shared `maxMotionLength` constant, the same value
// the preprocessing closure passes on, so the dataset filter and the target preprocessing
// cannot drift apart when the constant is changed.
var dataset = try Lang2Motion(
    motionDatasetURL: motionDatasetURL,
    batchSize: batchSize,
    minMotionLength: 20,
    maxMotionLength: maxMotionLength,
    trainTestSplit: 0.9,
    device: device
) { (motionSample: MotionSample) -> LangMotionBatch in
    // Converts one motion sample into a single-example LangMotionBatch:
    // the first annotation is preprocessed by the text processor, and the motion is turned
    // into the decoder input part plus the target.
    let sentence = textProcessor.preprocess(sentence: motionSample.annotations[0], maxTextSequenceLength: maxTextSequenceLength)
    let (motionPart, target) = LangMotionBatch.preprocessTargetMotion(sampleID: motionSample.sampleID, motion: motionSample.motion, maxMotionLength: maxMotionLength, shiftMaskRight: true)
    let source = LangMotionBatch.Source(sentence: sentence, motionPart: motionPart)
    return LangMotionBatch(data: source, label: target)
}

print("Dataset acquired.")

## Optimizer

In [None]:
// Weight-decayed Adam applied to every parameter of the model, configured from the
// hyper-parameters set in the training-params cell.
var optimizer = x10_optimizers_optimizer.GeneralOptimizer(
    for: model,
    TensorVisitorPlan(model.differentiableVectorView),
    defaultOptimizer: makeWeightDecayedAdam(
      learningRate: peakLearningRate,
      beta1: beta1,
      beta2: beta2,
      weightDecayRate: weightDecayRate
    )
)

// Learning-rate schedule: a fixed peak rate wrapped in a linear warm-up (20 steps), wrapped
// in a linear decay that starts at step 10.
// NOTE(review): the decay starts before the warm-up has finished — confirm this is intended.
// NOTE(review): the slope divides by stepsPerEpoch * nEpochs; if that product is 0 the
// division yields an infinite slope.
var scheduledLearningRate = LinearlyDecayedParameter(
  baseParameter: LinearlyWarmedUpParameter(
      baseParameter: FixedParameter<Float>(peakLearningRate),
      warmUpStepCount: 20,
      warmUpOffset: 0),
  slope: -(peakLearningRate / Float(stepsPerEpoch * nEpochs)),  // The LR decays linearly to zero.
  startStep: 10
)

/// Training-loop callback that installs the scheduled learning rate before each update.
///
/// Acts only on `.updateStart`; every other event is ignored. When `useBiasCorrection`
/// is enabled, the scheduled rate is additionally scaled by the Adam bias-correction factor.
///
/// - Parameters:
///   - loop: The running training loop; its optimizer must be a
///     `GeneralOptimizer<LangMotionTransformer>` (force-cast).
///   - event: The event currently being dispatched by the loop.
public func learningRateUpdater<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    guard event == .updateStart else {
        return
    }

    let generalOptimizer = loop.optimizer as! GeneralOptimizer<LangMotionTransformer>
    // for scheduled rates and bias correction, steps start at 1
    let currentStep = generalOptimizer.step + 1
    generalOptimizer.learningRate = scheduledLearningRate(forStep: UInt64(currentStep))

    if useBiasCorrection {
        let stepAsFloat = Float(currentStep)
        let correction = sqrtf(1 - powf(beta2, stepAsFloat)) / (1 - powf(beta1, stepAsFloat))
        generalOptimizer.learningRate *= correction
    }
    // print("\noptimizer: step: \(optimizer.step), learningRate: \(optimizer.learningRate)")
}

## Training helpers

In [None]:
// Loss function
let args = LossArgs(
        nb_joints: config.nbJoints,
        nb_mixtures: config.nbMixtures,
        mixture_regularizer_type: "None",  // ["cv", "l2", "None"]
        mixture_regularizer: 0.0,
        device: device
)

/// Loss adapter for the training loop: unwraps the transformer output and forwards its
/// `preds` to `normalMixtureSurrogateLoss` together with the shared `args`.
///
/// - Parameters:
///   - y_pred: Full model output; only `preds` takes part in the loss.
///   - y_true: Target for the batch.
/// - Returns: The loss tensor produced by `normalMixtureSurrogateLoss`.
@differentiable(wrt: y_pred)
func embeddedNormalMixtureSurrogateLoss(y_pred: LangMotionTransformerOutput<Float>, y_true: LangMotionBatch.Target) -> Tensor<Float> {
    let mixturePredictions = y_pred.preds
    return normalMixtureSurrogateLoss(y_pred: mixturePredictions, y_true: y_true, args: args)
}

/// Training-loop callback that writes a model checkpoint at the end of every epoch.
///
/// Acts only on `.epochEnd` and only when the loop reports an epoch index. The checkpoint
/// is written to `checkpointURL` as `model.e<N>`, where N is the one-based epoch number.
///
/// - Parameters:
///   - loop: The running training loop; its model must be a `LangMotionTransformer`.
///   - event: The event currently being dispatched by the loop.
/// - Throws: Whatever `writeCheckpoint` throws. The error is propagated to the caller
///   instead of crashing the process, which is what the function's `throws` signature promises.
public func saveCheckpoint<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    guard event == .epochEnd, let epochIndex = loop.epochIndex else {
        return
    }
    // A model of any other type is a programming error in the notebook, so the cast stays forced.
    let transformer = loop.model as! LangMotionTransformer
    try transformer.writeCheckpoint(to: checkpointURL, name: "model.e\(epochIndex+1)")
}

/// Collects loss statistics during training and writes them as scalar summaries
/// into the run directory (`rundirURL`).
///
/// Written tags:
/// - `TrainingLoss`: one point per finished batch, indexed by `trainingStepCount`.
/// - `EpochTrainingLoss`: the mean batch loss of an epoch, indexed by `epochIndex + 1`.
public class StatsRecorder {
    let summaryWriter = SummaryWriter(logdir: rundirURL, flushMillis: 30*1000)
    /// Number of batches recorded since this recorder was created (never reset).
    public var trainingStepCount = 0
    /// Number of batches recorded since the last `.epochStart`.
    public var trainingBatchCount = 0
    /// Sum of the batch losses recorded since the last `.epochStart`.
    public var trainingLossSum: Float = 0.0
    /// Zero-based epoch index; must be set by the caller before each epoch.
    public var epochIndex = 0 // FIXME: Workaround

    /// Training-loop callback that records statistics for the dispatched `event`.
    ///
    /// - Parameters:
    ///   - loop: The running training loop; `lastLoss` is read on `.batchEnd`.
    ///   - event: The event currently being dispatched by the loop.
    public func writeStats<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
        if event == .batchEnd {
            // Skip batches without a scalar loss instead of force-unwrapping.
            // NOTE(review): if the loop also emits .batchEnd for validation batches, those
            // losses are counted here too — confirm against the TrainingLoop implementation.
            guard
            // let batchIndex = loop.batchIndex,
            let trainingLoss = loop.lastLoss,
            let lossValue = trainingLoss.scalar else {
                return
            }
            // print("\nbatch stats: batchIndex: \(batchIndex), trainingStepCount: \(trainingStepCount), trainingLoss: \(trainingLoss)")
            summaryWriter.writeScalarSummary(tag: "TrainingLoss", step: trainingStepCount, value: lossValue)
            trainingStepCount += 1
            trainingBatchCount += 1
            trainingLossSum += Float(lossValue)
        }
        if event == .epochStart {
            trainingBatchCount = 0
            trainingLossSum = 0.0
        }
        if event == .epochEnd {
            // guard let epochIndex = loop.epochIndex else {
            //     return
            // }
            // An epoch without any recorded batch has no mean; dividing by zero would
            // write NaN into the summary, so nothing is written in that case.
            if trainingBatchCount > 0 {
                let current_epoch = epochIndex + 1
                let epochTrainingLoss = trainingLossSum / Float(trainingBatchCount)
                // print("\nepoch stats: current_epoch: \(current_epoch), epochTrainingLoss: \(epochTrainingLoss)")
                summaryWriter.writeScalarSummary(tag: "EpochTrainingLoss", step: current_epoch, value: epochTrainingLoss)
            }
        }
        if event == .fitEnd {
            // Flush so the last points are on disk when fit returns.
            summaryWriter.flush()
        }
    }
}

## Training loop

In [None]:
// start tensorboard
// cd /content/data
// tensorboard --bind_all --logdir runs/Lang2motion/

In [None]:
// let nEpochs = 1

In [None]:
// Recorder that writes batch and epoch losses as scalar summaries for this run.
let statsRecorder = StatsRecorder()

// Training loop
print("\nSetting up the training loop")
let trainingProgress = TrainingProgress(metrics: [.loss])
// Callbacks: progress display, summary statistics and the learning-rate schedule.
// saveCheckpoint is not registered here; checkpoints are written explicitly in the loop below.
var trainingLoop = TrainingLoop(
  training: dataset.trainEpochs,
  validation: dataset.testBatches,
  optimizer: optimizer,
  lossFunction: embeddedNormalMixtureSurrogateLoss,
  callbacks: [trainingProgress.update, statsRecorder.writeStats, learningRateUpdater])

print("\nTraining Transformer for the Lang2motion task!")
// FIXME: epoch loop workaround for checkpoint saving
// fit is called for one epoch at a time so that a checkpoint can be written after each epoch.
// Epoch numbering continues from start_epoch, so resumed runs keep counting upwards.
for epochIndex in start_epoch..<start_epoch+nEpochs {
    print("epoch \(epochIndex+1)/\(start_epoch + nEpochs)")
    // The recorder cannot read the real epoch index from the loop here, so it is set by hand.
    statsRecorder.epochIndex = epochIndex
    try! trainingLoop.fit(&model, epochs: 1, on: device)
    try! model.writeCheckpoint(to: checkpointURL, name: "model.e\(epochIndex+1)")
}

// Final checkpoint under a fixed name, written in addition to the per-epoch ones.
try! model.writeCheckpoint(to: checkpointURL, name: "model.final")
print("\nFinished training.")