# Train Transformer for the Motion2lang task

## Get sources

In [None]:
// for colab
%install-location $cwd/swift-install
%install-swiftpm-flags -c release
%install '.package(url: "https://github.com/wojtekcz/language2motion.git", .branch("master"))' Datasets TranslationModels TextModels ModelSupport SummaryWriter MotionLangModels TrainingLoop

## What's the GPU?

In [None]:
import Foundation

/// Runs `command` through `/bin/bash -c` and returns what the process wrote to standard output.
///
/// Only standard output is attached to the pipe; standard error is not redirected here.
/// The call blocks until the read end of the pipe reaches end-of-file.
///
/// - Parameter command: Shell source handed verbatim to `bash -c`.
/// - Returns: The captured output decoded as UTF-8. Byte sequences that are not valid
///   UTF-8 are replaced with U+FFFD rather than crashing the caller.
func shell(_ command: String) -> String {
    let task = Process()
    let pipe = Pipe()

    task.standardOutput = pipe
    task.arguments = ["-c", command]
    task.launchPath = "/bin/bash"
    task.launch()

    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    // Lossy decoding: command output is not guaranteed to be valid UTF-8, and the
    // failable `String(data:encoding:)` initializer would return nil (and trap when
    // force-unwrapped) on such output.
    return String(decoding: data, as: UTF8.self)
}

/// Executes `command` with `shell(_:)` and prints the captured standard output.
///
/// - Parameter command: Shell source to execute.
func sh(_ command: String) {
    let output = shell(command)
    print(output)
}

// Print the GPU status reported by `nvidia-smi`. PATH and LD_LIBRARY_PATH are extended
// first so the tool and the NVIDIA libraries can be located
// (paths presumably match the Colab image — confirm on other hosts).
sh("""
export PATH="$PATH:/opt/bin:/swift/toolchain/usr/bin"
export LD_LIBRARY_PATH="/usr/lib64-nvidia:$LD_LIBRARY_PATH"
nvidia-smi
""")

## Run Colab SSH server

In [None]:
// run colab ssh server
// after it finishes, interrupt cell execution
sh("bash <(curl -s https://raw.githubusercontent.com/wojtekcz/language2motion/master/notebooks/Colab/swift_colab_ssh_server.sh)")

In [None]:
sh("ps ax|grep ssh")

In [None]:
sh("kill -9 2211")  // enter ssh pid to kill the tunnel

## Imports

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import ModelSupport
import Datasets
import SummaryWriter
import MotionLangModels
import TrainingLoop
import x10_optimizers_optimizer

In [None]:
import PythonKit

%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")

## Download data

In [None]:
let dataset_name = "motion_dataset_v3.10Hz.multi"

In [None]:
sh("mkdir -p /content/data/")
sh("""
cd /content/data/
wget -nv --show-progress -N https://github.com/wojtekcz/language2motion/releases/download/v0.3.0/\(dataset_name).tgz
wget -nv -N https://github.com/wojtekcz/language2motion/releases/download/v0.1.0/vocab.txt
tar xzvf \(dataset_name).tgz --skip-old-files
""")

## Set training params

In [None]:
// Core hyper-parameters of this run.
let runName = "run_5"
let batchSize = 300
let maxMotionLength = 50
let maxTextSequenceLength = 40
let nEpochs = 5
let peakLearningRate: Float = 5e-4

// Steps per epoch; feeds the slope of the learning-rate decay
// (`peakLearningRate / Float(stepsPerEpoch * nEpochs)`).
// The expression uses integer division, so for batchSize > 168 it evaluated to 0 and the
// slope computation divided by zero. Clamp to at least one step to keep the schedule finite.
// NOTE(review): 168 is a hard-coded "function of training set size and batching
// configuration" — confirm it matches the dataset that is actually loaded.
let stepsPerEpoch = max(1, 168 / batchSize * 2)

// Coefficients passed to `makeWeightDecayedAdam`, and the switch read by
// `learningRateUpdater` to decide whether to apply bias correction to the learning rate.
let beta1: Float = 0.9
let beta2: Float = 0.999
let useBiasCorrection = false

// Dataset variant; its raw value is spliced into the dataset file name below.
let datasetSize: DatasetSize = .multi_full
// let datasetSize: DatasetSize = .micro

// Echo the effective configuration so it is visible in the notebook output.
print("runName: \(runName)")
print("batchSize: \(batchSize)")
print("maxMotionLength: \(maxMotionLength)")
print("maxTextSequenceLength: \(maxTextSequenceLength)")
print("nEpochs: \(nEpochs)")
print("peakLearningRate: \(peakLearningRate)")
print("datasetSize: \(datasetSize)")
print("stepsPerEpoch: \(stepsPerEpoch)")

// Root data directory and the motion dataset file inside it.
// NOTE(review): no "." is inserted between the raw value and "plist"; presumably
// `DatasetSize.rawValue` already ends with one — confirm.
let dataURL = URL(fileURLWithPath: "/content/data/")
let motionDatasetURL = dataURL.appendingPathComponent("motion_dataset_v3.10Hz.\(datasetSize.rawValue)plist")

// Per-run log directory (used by the summary writer) and its checkpoint subdirectory.
let logdirURL = dataURL.appendingPathComponent("runs/Motion2lang/\(runName)", isDirectory: true)
let checkpointURL = logdirURL.appendingPathComponent("checkpoints", isDirectory: true)

## Select eager or X10 backend

In [None]:
let device = Device.defaultXLA
// let device = Device.defaultTFEager
print("backend: \(device)")

## X10 warm-up

In [None]:
let eagerTensor1 = Tensor([0.0, 1.0, 2.0])
let eagerTensor2 = Tensor([1.5, 2.5, 3.5])
let eagerTensorSum = eagerTensor1 + eagerTensor2
// print(eagerTensorSum)
// print(eagerTensor1.device)
let x10Tensor2 = Tensor([1.5, 2.5, 3.5], on: Device.defaultXLA)
// print(x10Tensor2.device)

## Instantiate model

In [None]:
// Instantiate the text processor: vocabulary loaded from vocab.txt in the data directory,
// a case-insensitive BERT tokenizer built on it, and the processor combining both.
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
let vocabulary: Vocabulary = try! Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(vocabulary: vocabulary, caseSensitive: false, unknownToken: "[UNK]", maxTokenLength: nil)
let textProcessor = LegacyTextProcessor(vocabulary: vocabulary, tokenizer: tokenizer)

// Instantiate the model configuration. The vocabulary size is taken from the loaded
// vocabulary; the remaining values are fixed hyper-parameters of this notebook.
let modelSize = 128
let config = MotionLangTransformerConfig(
    vocabSize: vocabulary.count,
    nbJoints: 47,
    layerCount: 6,
    modelSize: modelSize,
    feedForwardSize: 512,
    headCount: 4,
    dropoutProbability: 0.1,
    sentenceMaxPositionalLength: 100,
    motionMaxPositionalLength: 500
)

// First epoch index of the training loop; raise it when resuming from a checkpoint.
var start_epoch = 0

/// Create a new model.
var model = MotionLangTransformer(config: config)

/// Alternatively, load a model checkpoint (uncomment to resume training).
// print("checkpointURL: \(checkpointURL.path)")
// start_epoch = 2
// model = try! MotionLangTransformer(checkpoint: checkpointURL, config: config, name: "model.e\(start_epoch)")
// model = try! MotionLangTransformer(checkpoint: checkpointURL, config: config, name: "model.final")

## Load dataset

In [None]:
print("\nLoading dataset...")

// Build the Motion2Lang dataset. The trailing closure converts each MotionSample into a
// MotionLangBatch through the text processor, using the notebook-level length limits.
var dataset = try Motion2Lang(
    motionDatasetURL: motionDatasetURL,
    batchSize: batchSize,
    minMotionLength: 20,
    maxMotionLength: 100,
    trainTestSplit: 0.9,
    device: device
) { (motionSample: MotionSample) -> MotionLangBatch in
    return textProcessor.preprocess(
        motionSample: motionSample,
        maxMotionLength: maxMotionLength,
        maxTextSequenceLength: maxTextSequenceLength
    )
}

print("Dataset acquired.")

## Test model with one batch

In [None]:
// // get a batch
// print("\nOne batch (MotionLangBatch):")
// var epochIterator = dataset.trainEpochs.enumerated().makeIterator()
// let epoch = epochIterator.next()
// let batches = Array(epoch!.1)
// let batch: MotionLangBatch = batches[0]
// print("type: \(type(of:batch))")
// print("motion.shape: \(batch.motion.shape)")
// // print("motionFlag.shape: \(batch.motionFlag.shape)")
// print("mask.shape: \(batch.mask.shape)")
// print("origMotionFramesCount.shape: \(batch.origMotionFramesCount.shape)")
// print("origMotionFramesCount: \(batch.origMotionFramesCount)")
// print("targetTokenIds.shape: \(batch.targetTokenIds.shape)")
// print("targetMask.shape: \(batch.targetMask.shape)")
// print("targetTruth.shape: \(batch.targetTruth.shape)")

In [None]:
// // run one batch
// print("\nRun one batch:")
// print("==============")
// let deviceBatch = MotionLangBatch(copying: batch, to: device)
// let output = model(deviceBatch)
// print("output.shape: \(output.shape)")

## Optimizer

In [None]:
// Optimizer over the model's differentiable parameters, using weight-decayed Adam as the
// default update rule. The learning rate set here is overwritten on every update by
// `learningRateUpdater`.
var optimizer = x10_optimizers_optimizer.GeneralOptimizer(
    for: model,
    TensorVisitorPlan(model.differentiableVectorView),
    defaultOptimizer: makeWeightDecayedAdam(
      learningRate: peakLearningRate,
      beta1: beta1,
      beta2: beta2
    )
)

// Learning-rate schedule: a fixed peak rate wrapped in a linear warm-up, wrapped in a
// linear decay whose slope is sized from `stepsPerEpoch * nEpochs`.
// NOTE(review): the slope divides by `stepsPerEpoch * nEpochs`; a zero product makes it
// infinite — verify `stepsPerEpoch` is non-zero for the chosen batch size.
// NOTE(review): `startStep` (10) is smaller than `warmUpStepCount` (20), so the decay
// presumably begins before the warm-up finishes — confirm this is intended.
var scheduledLearningRate = LinearlyDecayedParameter(
  baseParameter: LinearlyWarmedUpParameter(
      baseParameter: FixedParameter<Float>(peakLearningRate),
      warmUpStepCount: 20,
      warmUpOffset: 0),
  slope: -(peakLearningRate / Float(stepsPerEpoch * nEpochs)),  // The LR decays linearly to zero.
  startStep: 10
)

/// Training-loop callback that assigns the scheduled learning rate to the optimizer.
///
/// Acts only on `.updateStart`; every other event is ignored. When `useBiasCorrection`
/// is set, the scheduled rate is additionally scaled by
/// `sqrt(1 - beta2^step) / (1 - beta1^step)`.
///
/// - Parameters:
///   - loop: The running training loop; its optimizer must be a
///     `GeneralOptimizer<MotionLangTransformer>` (the cast traps otherwise).
///   - event: The event being dispatched.
public func learningRateUpdater<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    guard event == .updateStart else { return }

    let generalOptimizer = loop.optimizer as! GeneralOptimizer<MotionLangTransformer>
    // Scheduled rates and bias correction count steps starting at 1.
    let currentStep = generalOptimizer.step + 1
    var rate = scheduledLearningRate(forStep: UInt64(currentStep))
    if useBiasCorrection {
        let t = Float(currentStep)
        rate *= sqrtf(1 - powf(beta2, t)) / (1 - powf(beta1, t))
    }
    generalOptimizer.learningRate = rate
}

## Training helpers

In [None]:
/// Records training statistics through a `SummaryWriter` rooted at `logdirURL`.
///
/// Register `writeStats` as a training-loop callback. It writes one "TrainingLoss" scalar
/// per `.batchEnd` event that carries a loss, and one "EpochTrainingLoss" scalar (the mean
/// of the losses accumulated since `.epochStart`) per `.epochEnd` event.
public class StatsRecorder {
    let summaryWriter = SummaryWriter(logdir: logdirURL, flushMillis: 30*1000)
    /// Number of batch losses written so far; used as the step of "TrainingLoss".
    public var trainingStepCount = 0
    /// Number of batch losses accumulated since the last `.epochStart`.
    public var trainingBatchCount = 0
    /// Sum of the batch losses accumulated since the last `.epochStart`.
    public var trainingLossSum: Float = 0.0
    /// Zero-based epoch index; must be set by the caller before each epoch.
    public var epochIndex = 0 // FIXME: Workaround

    /// Training-loop callback; dispatches on `event` as described on the class.
    ///
    /// - Parameters:
    ///   - loop: The running training loop; only `lastLoss` is read.
    ///   - event: The event being dispatched.
    public func writeStats<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
        if event == .batchEnd {
            // NOTE(review): `.batchEnd` is not distinguished between training and validation
            // here — confirm whether the loop also emits it for validation batches.
            // Unwrap the scalar once; skip the event when no scalar loss is available.
            guard let lossValue = loop.lastLoss?.scalar else {
                return
            }
            summaryWriter.writeScalarSummary(tag: "TrainingLoss", step: trainingStepCount, value: lossValue)
            trainingStepCount += 1
            trainingBatchCount += 1
            trainingLossSum += Float(lossValue)
        } else if event == .epochStart {
            trainingBatchCount = 0
            trainingLossSum = 0.0
        } else if event == .epochEnd {
            // An epoch without recorded batches would divide by zero and log NaN.
            guard trainingBatchCount > 0 else {
                return
            }
            let current_epoch = epochIndex + 1
            let epochTrainingLoss = trainingLossSum / Float(trainingBatchCount)
            summaryWriter.writeScalarSummary(tag: "EpochTrainingLoss", step: current_epoch, value: epochTrainingLoss)
        } else if event == .fitEnd {
            summaryWriter.flush()
        }
    }
}

let statsRecorder = StatsRecorder()

In [None]:
/// Softmax cross-entropy between the model output and the target token ids.
///
/// The prediction is reshaped to `[first * last, -1]`, where `first` and `last` are the
/// first and last dimensions of `targetTruth`; the labels are flattened to one dimension.
///
/// - Parameters:
///   - y_pred: Model output; differentiated with respect to this argument.
///   - y_true: Target holding `targetTruth`.
/// - Returns: The loss computed by `softmaxCrossEntropy(logits:labels:)`.
@differentiable(wrt: y_pred)
func embeddedSoftmaxCrossEntropy(y_pred: Tensor<Float>, y_true: MotionLangBatch.MLTarget) -> Tensor<Float> {
    let truthShape = y_true.targetTruth.shape
    let rowCount = truthShape.last! * truthShape.first!
    // TODO: ignore padded entries
    return softmaxCrossEntropy(
        logits: y_pred.reshaped(to: [rowCount, -1]),
        labels: y_true.targetTruth.reshaped(to: [-1])
    )
}

## Set up decoding

In [None]:
/// Greedily decodes a token sequence for a single encoded motion.
///
/// The motion is encoded once; then, for `maxLength` iterations, the tokens decoded so far
/// are fed to the decoder and the arg-max token of the last position is appended.
/// The loop always runs `maxLength` times — there is no early stop on an end-of-sequence token.
///
/// - Parameters:
///   - model: Model providing `encode`, `decode` and `generate`.
///   - input: Source part of a batch; its motion fields are forwarded unchanged to the decoder input.
///   - maxLength: Number of tokens to generate after the start symbol.
///   - startSymbol: Token id used as the first element of the sequence.
/// - Returns: Tensor of shape `[1, maxLength + 1]` holding the start symbol followed by the decoded ids.
func greedyDecode(model: MotionLangTransformer, input: MotionLangBatch.MLSource, maxLength: Int, startSymbol: Int32) -> Tensor<Int32> {
    // Encoder output is computed once and reused in every decoding step.
    let memory = model.encode(input: input).lastLayerOutput
    var ys = Tensor(repeating: startSymbol, shape: [1,1])
    for _ in 0..<maxLength {
        // Despite the `motionPart` prefix, these tensors are sized by the decoded token
        // sequence `ys` and end up as the decoder's `targetMask`.
        let motionPartFlag = Tensor<Int32>(repeating: 1, shape: [1, ys.shape[1]])
        var motionPartMask = MotionLangBatch.makeStandardMask(target: motionPartFlag, pad: 0, shiftRight: true)
        // All flags are 1, so the sum equals the current sequence length.
        let motionLen = Int(motionPartFlag.sum().scalar!)
        // Subtract 1 from all rows except the last, then take the absolute value.
        // Presumably this flips 0 <-> 1 in those rows (assumes a 0/1 mask) — TODO confirm.
        motionPartMask[0, 0..<motionLen-1, 0..<motionLen] -= 1
        motionPartMask = abs(motionPartMask)

        let decoderInput = MotionLangBatch.MLSource(sampleID: input.sampleID, motion: input.motion,
                                     mask: input.mask,
                                     origMotionFramesCount: input.origMotionFramesCount,
                                     targetTokenIds: ys,
                                     targetMask: motionPartMask
                                     )
        let out = model.decode(input: decoderInput, memory: memory).lastLayerOutput
        // Only the output at the last position is used to pick the next token.
        let prob = model.generate(input: out[0...,-1])
        let nextWord = Int32(prob.argmax().scalarized())
        ys = Tensor(concatenating: [ys, Tensor(repeating: nextWord, shape: [1,1])], alongAxis: 1) // , on: device
    }
    return ys
}

In [None]:
/// Greedily decodes one motion sample of the global `dataset` and prints the result.
///
/// Prints the sample header (id, first annotation, last timestep, frame count) followed by
/// the decoded string. If `sample_id` is not present in `dataset.motionSampleDict`, a
/// message is printed and nothing is decoded.
///
/// - Parameters:
///   - sample_id: Key into `dataset.motionSampleDict`.
///   - maxLength: Number of tokens to generate after the start symbol.
func greedyDecodeSample(_ sample_id: Int, maxLength: Int = 15) {
    // Sample IDs are hard-coded by callers and differ between dataset variants; skip an
    // unknown ID instead of trapping in the middle of a training run.
    guard let motionSample = dataset.motionSampleDict[sample_id] else {
        print("\nsample: \(sample_id) not found in dataset, skipping")
        return
    }
    print("\nsample: \(motionSample.sampleID), \"\(motionSample.annotations[0])\", motion: \(motionSample.timesteps[-1]) sec (\(motionSample.motion.shape[0]) frames)")

    let singleBatch = textProcessor.preprocess(motionSample: motionSample, maxMotionLength: maxMotionLength, maxTextSequenceLength: maxTextSequenceLength)
    let out = greedyDecode(model: model, input: singleBatch.source, maxLength: maxLength, startSymbol: textProcessor.bosId)
    let outputStr = textProcessor.decode(tensor: out)
    print("decoded: \"\(outputStr)\"")
}

In [None]:
// Samples decoded after every epoch to inspect progress. Each entry maps "sampleID" to an
// Int and "text" to a String (presumably the reference description — confirm).
// The values are heterogeneous, so the element type is spelled out explicitly as
// `[String: Any]`; consumers cast with `as! Int` / `as? String`.
let samplesToDecode: [[String: Any]] = [
//    ["sampleID": 733, "text": "Ala ma kota."], // for .micro dataset
    ["sampleID": 449, "text": "A person runs forward."],
    ["sampleID": 3921, "text": "A human is swimming."],
    ["sampleID": 843, "text": "A person walks."],
    ["sampleID": 1426, "text": "A person plays the air guitar."],
    ["sampleID": 1292, "text": "A person performs a squat."],
    ["sampleID": 1315, "text": "A human raises their left foot and touches it with the right hand."]
]

## Training loop

In [None]:
// start tensorboard
// cd /content/data
// tensorboard --bind_all --logdir runs/Motion2lang/

In [None]:
// let nEpochs = 10

In [None]:
// Training loop
// Wires the dataset, optimizer, loss and callbacks into a TrainingLoop, then trains one
// epoch at a time so that a checkpoint can be written and samples decoded after each epoch.
print("\nSetting up the training loop")
let trainingProgress = TrainingProgress(metrics: [.loss])
var trainingLoop: TrainingLoop = TrainingLoop(
    training: dataset.trainEpochs,
    validation: dataset.testBatches,
    optimizer: optimizer,
    lossFunction:  embeddedSoftmaxCrossEntropy,
    callbacks: [trainingProgress.update, statsRecorder.writeStats, learningRateUpdater]
//    callbacks: []
)

print("\nTraining Transformer for the Motion2lang task!")
// FIXME: epoch loop workaround for checkpoint saving
for epochIndex in start_epoch..<start_epoch+nEpochs {
    print("epoch \(epochIndex+1)/\(start_epoch + nEpochs)")
    // The recorder cannot read the epoch from the loop, so it is handed over explicitly.
    statsRecorder.epochIndex = epochIndex
    // `fit` is called with a single epoch; the outer `for` provides the epoch iteration.
    try! trainingLoop.fit(&model, epochs: 1, on: device)
    try! model.writeCheckpoint(to: checkpointURL, name: "model.e\(epochIndex+1)")

    // Decode the reference samples in inference mode on the eager device, then move the
    // model back to the training device for the next epoch.
    Context.local.learningPhase = .inference
    model.move(to: Device.defaultTFEager)
    for sample in samplesToDecode {
        greedyDecodeSample(sample["sampleID"] as! Int, maxLength: 20)
    }
    model.move(to: device)
}

// Final checkpoint, written after the last epoch under a fixed name.
try! model.writeCheckpoint(to: checkpointURL, name: "model.final")
print("\nFinished training.")

## Generate motion description

In [None]:
// Decode a single sample by hand: pick one of the sample IDs below, switch to inference
// mode, run the decoder on the eager device, then move the model back to `device`.
// let sample_id = 2410
// let sample_id = 446
let sample_id = 449
// let sample_id = 733
Context.local.learningPhase = .inference
model.move(to: Device.defaultTFEager)
greedyDecodeSample(sample_id, maxLength: 100)
model.move(to: device)