# Generate motion

In [None]:
// For local development: install the listed products from the package checkout
// at /notebooks/language2motion.gt.
// Install destination used by the %install directives below.
%install-location /notebooks/language2motion.gt/swift-install
// Flags handed to SwiftPM ("-c release").
%install-swiftpm-flags -c release
%install '.package(path: "/notebooks/language2motion.gt")' Datasets TranslationModels TextModels ModelSupport SummaryWriter LangMotionModels

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import FoundationXML
import ModelSupport
import Datasets
import SummaryWriter
import LangMotionModels

In [None]:
import PythonKit

// Include the IPython display helpers, then ask the shell to use the
// "inline" matplotlib mode.
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")

## What's the GPU?

In [None]:
import Foundation

/// Runs `command` through `/bin/bash -c` and returns what it wrote to standard output.
///
/// Standard error is not captured; only standard output is redirected into the pipe.
///
/// - Parameter command: Command line handed verbatim to bash.
/// - Returns: The captured standard output decoded as UTF-8, with any invalid byte
///   sequences replaced by U+FFFD. If bash cannot be started, a one-line description
///   of the launch failure is returned instead.
func shell(_ command: String) -> String {
    let task = Process()
    let pipe = Pipe()

    task.standardOutput = pipe
    task.arguments = ["-c", command]
    task.executableURL = URL(fileURLWithPath: "/bin/bash")

    do {
        try task.run()
    } catch {
        return "shell: failed to launch /bin/bash: \(error)"
    }

    // Drain the pipe before waiting so a command with a lot of output cannot
    // stall on a full pipe buffer.
    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    task.waitUntilExit()

    // Lossy decoding: never crashes on output that is not valid UTF-8.
    return String(decoding: data, as: UTF8.self)
}

/// Runs `command` with `shell(_:)` and prints the text it returns.
///
/// - Parameter command: Command line forwarded unchanged to `shell(_:)`.
func sh(_ command: String) {
    let output = shell(command)
    print(output)
}

## Download data

In [None]:
// Which size variant of the motion dataset to use.
let datasetSize: DatasetSize = .mini
// Base name of the motion dataset file: a fixed stem followed by the variant's raw value.
// NOTE(review): "plist" is later appended to this name with no separating dot, so
// rawValue presumably already ends with "." — confirm against DatasetSize.
let dataset_name = "motion_dataset_v3.10Hz.\(datasetSize.rawValue)"

## Set training params

In [None]:
// Label for this run (only printed in the cells visible here).
let runName = "run_1"
// Batch size passed to the dataset loader; the commented-out line is a larger alternative.
let batchSize = 2
// let batchSize = 150
// Length limits handed to the text processor; maxMotionLength is also the number of
// decoding steps in greedyDecodeMotion.
let maxTextSequenceLength =  20
let maxMotionLength =  10
// Training hyperparameters; not referenced by the cells visible here beyond the printout.
let nEpochs = 5
let learningRate: Float = 5e-4

// Echo the chosen settings.
print("runName: \(runName)")
print("batchSize: \(batchSize)")
print("maxTextSequenceLength: \(maxTextSequenceLength)")
print("maxMotionLength: \(maxMotionLength)")
print("nEpochs: \(nEpochs)")
print("learningRate: \(learningRate)")

// Data directory; the commented-out line is an alternative location.
// let dataURL = URL(fileURLWithPath: "/content/data/")
let dataURL = URL(fileURLWithPath: "/notebooks/language2motion.gt/data/")
// NOTE(review): "plist" is appended directly to dataset_name with no dot — this relies on
// dataset_name already ending in "."; confirm.
let motionDatasetURL = dataURL.appendingPathComponent("\(dataset_name)plist")
let langDatasetURL = dataURL.appendingPathComponent("labels_ds_v2.csv")

## Select eager or X10 backend

In [None]:
// Pick the device the model is moved to: the eager device is active,
// the XLA device is kept as a commented-out alternative.
// let device = Device.defaultXLA
let device = Device.defaultTFEager
print(device)

## Instantiate model

In [None]:
// Instantiate the text processor: vocabulary loaded from vocab.txt in the data
// directory, a case-insensitive BERT tokenizer, and the sequence-length limits set earlier.
// Note: `try!` aborts the cell if the vocabulary file cannot be read.
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
let vocabulary: Vocabulary = try! Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(vocabulary: vocabulary, caseSensitive: false, unknownToken: "[UNK]", maxTokenLength: nil)
let textProcessor = TextProcessor2(vocabulary: vocabulary, tokenizer: tokenizer, maxTextSequenceLength: maxTextSequenceLength, maxMotionLength: maxMotionLength)

// Instantiate the model: transformer hyperparameters.
let vocabSize = vocabulary.count
let nbJoints = 47 // TODO: get value from dataset
let layerCount: Int = 6
let modelSize: Int = 256
let feedForwardSize: Int = 1024
let headCount: Int = 8
let dropoutProbability: Double = 0.1

var transformer = LangMotionTransformer(
    vocabSize: vocabSize, 
    nbJoints: nbJoints,
    layerCount: layerCount, 
    modelSize: modelSize, 
    feedForwardSize: feedForwardSize, 
    headCount: headCount, 
    dropoutProbability: dropoutProbability
)

// Mixture model applied to the transformer's generated output in greedyDecodeMotion;
// its input size is set to nbJoints.
let nbMixtures = 20
// TODO: integrate MotionGaussianMixtureModel with Generator
var mixtureModel = MotionGaussianMixtureModel(inputSize: nbJoints, nbJoints: nbJoints, nbMixtures: nbMixtures)
// mixtureModel.move(to: device)

// Combine both parts and move the combined model to the selected device.
var model = LangMotionModel(transformer: transformer, mixtureModel: mixtureModel)
model.move(to: device)

## Load dataset

In [None]:
print("\nLoading dataset...")

// Build the dataset from the motion and language files; every example is
// converted to a LangMotionBatch by the text processor.
var dataset = try Lang2Motion(
    motionDatasetURL: motionDatasetURL,
    langDatasetURL: langDatasetURL,
    batchSize: batchSize
) { (example: Lang2Motion.Example) -> LangMotionBatch in
    return textProcessor.preprocess(example: example)
}

print("Dataset acquired.")

## Generate motion

In [None]:
/// Generates a motion for `sentence` by step-by-step decoding and renders it with `motionToImg`.
///
/// The sentence is preprocessed and encoded once. Then, for `maxMotionLength` steps, the
/// decoder is run on the frames produced so far, the mixture model's prediction for the
/// last position is sampled, and the sampled frame is appended. The accumulated frames are
/// passed through `dataset.scaler.inverse_transform` before being drawn and returned.
///
/// Side effects: sets `Context.local.learningPhase` to `.inference` (it is not restored)
/// and prints progress information.
///
/// - Parameters:
///   - sentence: Text to generate a motion for.
///   - prefix: Used in the image description and, when an image URL is built, as the
///     PNG file name under `motion_images/` in `dataURL`.
///   - showMotion: When `true`, `motionToImg` receives no URL; when `false`, it receives
///     the PNG URL and that path is printed afterwards.
///     NOTE(review): presumably a nil URL means "display inline" — confirm in `motionToImg`.
/// - Returns: The descaled motion; its first frame comes from the all-zero start frame.
public func greedyDecodeMotion(sentence: String, prefix: String = "prefix", showMotion: Bool = false) -> Tensor<Float> {
    // TODO: incorporate done/stop signal
    // TODO: save mmm file
    Context.local.learningPhase = .inference
    print("\ngreedyDecodeMotion(sentence: \"\(sentence)\")")

    let source = textProcessor.preprocess(sentence: sentence)
    source.printSource()

    print("\nEncode:")
    print("======")
    let memory = model.transformer.encode(input: source)
    print("  memory.count: \(memory.shape)")

    print("\nGenerate:")
    print("=========")
    // Start from a single all-zero ("neutral") frame of shape [1, 1, nbJoints].
    var ys = Tensor<Float>(repeating: 0.0, shape: [1, 1, nbJoints])
    for _ in 0..<maxMotionLength {
        // prepare input: mask sized to the frames generated so far
        let targetMask = Tensor<Float>(subsequentMask(size: ys.shape[1]))
        let target = LangMotionBatch.Target(motion: ys, mask: targetMask)

        // decode motion and predict mixture parameters from the last decoder position
        let out = model.transformer.decode(sourceMask: source.mask, target: target, memory: memory)
        let singlePreds = model.mixtureModel(model.transformer.generate(input: out[0...,-1].expandingShape(at: 0)))

        // perform sampling; the log-probabilities and done signal are not used yet (see TODO above)
        let (sampledMotion, _, _) = MotionDecoder.performNormalMixtureSampling(
            preds: singlePreds, nb_joints: nbJoints, nb_mixtures: nbMixtures, maxMotionLength: maxMotionLength)

        // concatenate motion along the frame axis
        ys = Tensor(concatenating: [ys, sampledMotion.expandingShape(at: 0)], alongAxis: 1)
    }

    // descale motion (batch axis removed first)
    let descaled_motion = dataset.scaler.inverse_transform(ys.squeezingShape(at:0))
    print("  descaled_motion.shape: \(descaled_motion.shape)")

    // An image URL is only built when the motion is not shown.
    let imageURL: URL? = showMotion ? nil : dataURL.appendingPathComponent("motion_images/\(prefix).png")
    motionToImg(url: imageURL, motion: descaled_motion, motionFlag: nil, padTo: maxMotionLength, descr: "\(prefix), \(sentence)")
    if let imageURL = imageURL {
        print("Saved image: \(imageURL.path)")
    }
    return descaled_motion
}

In [None]:
// Generate a motion for a sample sentence with showMotion enabled.
let motion = greedyDecodeMotion(sentence: "human is walking", prefix: "foo9", showMotion: true)

# Save to MMM

In [None]:
// Inspect the shape of the generated motion tensor.
motion.shape

In [None]:
// Take the joint names from the first training example's motion sample.
let jointNames = dataset.trainExamples[0].motionSample.jointNames

In [None]:
// Build the MMM XML document from the joint names and the generated motion.
let xmlDoc = MMMWriter.getMMMXMLDoc(jointNames: jointNames, motion: motion)

In [None]:
// Write the pretty-printed MMM document to generated_motions/ in the data directory.
// The target directory is created first if needed, and the throwing calls are marked
// with `try` so a failed write surfaces as an error.
let mmmURL = dataURL.appendingPathComponent("generated_motions/generated_1.mmm.xml")
try FileManager.default.createDirectory(at: mmmURL.deletingLastPathComponent(), withIntermediateDirectories: true)
try xmlDoc.xmlData(options: XMLNode.Options.nodePrettyPrint).write(to: mmmURL)

In [None]:
// Print the pretty-printed XML for inspection.
print(xmlDoc.xmlString(options: XMLNode.Options.nodePrettyPrint))