# Motion generation from checkpoints

In [None]:
// for local development
// Builds the listed modules from the local package checkout in release mode
// and installs them into the given location so the imports in the next cell resolve.
%install-location /notebooks/language2motion.gt/swift-install
%install-swiftpm-flags -c release
%install '.package(path: "/notebooks/language2motion.gt")' Datasets TranslationModels TextModels ModelSupport SummaryWriter LangMotionModels Checkpoints

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import FoundationXML
import ModelSupport
import Datasets
import SummaryWriter
import LangMotionModels
import Checkpoints
import PythonKit

In [None]:
// NumPy module handle obtained through the Python interop layer.
let np  = Python.import("numpy")

In [None]:
// Pull in the IPython display helpers and ask matplotlib to render figures inline.
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")

## Set training params

In [None]:
// Device handed to the dataset loader further down.
let device = Device.defaultTFEager

In [None]:
// Maximum text sequence length passed to `textProcessor.preprocess`.
let maxTextSequenceLength =  20
// Maximum motion length used for target-motion preprocessing and as the padding width of plots.
let maxMotionLength =  50

In [None]:
// Dataset variant (its raw value becomes part of the dataset file name) and loader batch size.
let datasetSize: DatasetSize = .full
let batchSize = 150

In [None]:
// Root directory holding the dataset, vocabulary and training runs.
let dataURL = URL(fileURLWithPath: "/notebooks/language2motion.gt/data/")
// NOTE(review): there is no "." between the raw value and "plist" — presumably the raw value
// carries its own trailing dot (or is empty); confirm against `DatasetSize`.
let motionDatasetURL = dataURL.appendingPathComponent("motion_dataset_v3.10Hz.\(datasetSize.rawValue)plist")

In [None]:
/// instantiate text processor
// The vocabulary is read from `vocab.txt` in the data directory; `try!` aborts the cell if loading fails.
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
let vocabulary: Vocabulary = try! Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(vocabulary: vocabulary, caseSensitive: false, unknownToken: "[UNK]", maxTokenLength: nil)
let textProcessor = TextProcessor(vocabulary: vocabulary, tokenizer: tokenizer)

// model config
// NOTE(review): presumably these values must match the configuration the checkpoint was trained with — confirm.
let modelSize = 128
let config = LangMotionTransformerConfig(
    vocabSize: vocabulary.count,
    nbJoints: 47, // TODO: get value from dataset
    nbMixtures: 20,
    layerCount: 6,
    modelSize: modelSize,
    feedForwardSize: 512,
    headCount: 4,
    dropoutProbability:  0.1,
    sentenceMaxPositionalLength: 100,
    motionMaxPositionalLength: 500,
    // Previously tried attention temperatures, kept for reference:
//     encoderSelfAttentionTemp: Double(modelSize*modelSize),
//     decoderSourceAttentionTemp: Double(modelSize*modelSize),
    // Encoder self-attention and decoder source-attention use sqrt(modelSize);
    // decoder self-attention uses modelSize itself.
    encoderSelfAttentionTemp: sqrt(Double(modelSize)),
    decoderSourceAttentionTemp: sqrt(Double(modelSize)),
    decoderSelfAttentionTemp: Double(modelSize)
)

In [None]:
print("\nLoading dataset...")

// Load the motion dataset and convert every motion sample into a single-sample LangMotionBatch.
// The motion length limit comes from the shared `maxMotionLength` constant so that the
// dataset filter and the target-motion preprocessing in the closure cannot drift apart.
var dataset = try Lang2Motion(
    motionDatasetURL: motionDatasetURL,
    batchSize: batchSize,
    minMotionLength: 20,
    maxMotionLength: maxMotionLength,
    trainTestSplit: 1.0,
    device: device
) { (motionSample: MotionSample) -> LangMotionBatch in
    // Only the first annotation of a sample is used as its sentence.
    let sentence = textProcessor.preprocess(sentence: motionSample.annotations[0], maxTextSequenceLength: maxTextSequenceLength)
    let (motionPart, target) = LangMotionBatch.preprocessTargetMotion(sampleID: motionSample.sampleID, motion: motionSample.motion, maxMotionLength: maxMotionLength, shiftMaskRight: true)
    let source = LangMotionBatch.Source(sentence: sentence, motionPart: motionPart)
    return LangMotionBatch(data: source, label: target)
}

print("Dataset acquired.")

## Helpers

In [None]:
// Python module handles used by the plotting helpers below.
let plt = Python.import("matplotlib.pyplot")
let np = Python.import("numpy")

/// Displays `tensor` as an image with matplotlib, using the "Spectral" colour map.
func tensorShow2(_ tensor: Tensor<Float>) {
    let pixels = tensor.makeNumpyArray()
    plt.imshow(pixels, cmap: "Spectral")
    plt.show()
}

In [None]:
/// Notebook-local greedy motion decoder built on top of `LangMotionTransformer`.
public class MotionDecoder2 {

    /// Generates a motion for `sentence` autoregressively: the sentence is encoded once, then
    /// each step decodes the motion produced so far, samples from the mixture model and appends
    /// the sampled motion along the frame axis.
    ///
    /// - Parameters:
    ///   - sentence: Preprocessed sentence fed to the encoder.
    ///   - startMotion: Optional leading frames placed right after the initial neutral frame.
    ///   - transformer: Model providing `encode`, `decode` and `mixtureModel`.
    ///   - nbJoints: Joint count used to build the neutral frame and passed to the sampler.
    ///   - nbMixtures: Mixture count passed to the sampler.
    ///   - maxMotionLength: Target motion length; the number of generation steps is reduced by
    ///     the number of leading frames and never goes below zero.
    ///   - memoryMultiplier: Factor applied to the encoder's last layer output before decoding.
    ///   - showAttentionProbs: When `true`, plots `attentionProbs[0, 0]` of every layer
    ///     (presumably first batch item, first head — confirm).
    /// - Returns: The motion with the added leading batch axis removed and the initial neutral
    ///   frame dropped.
    public static func greedyDecodeMotion2(
        sentence: LangMotionBatch.Sentence,
        startMotion: Tensor<Float>?,
        transformer: LangMotionTransformer,
        nbJoints: Int,
        nbMixtures: Int,
        maxMotionLength: Int,
        memoryMultiplier: Float = 1.0,
        showAttentionProbs: Bool = false
    ) -> Tensor<Float> {
        print("\nEncode:")
        print("======")
        let encoded = transformer.encode(input: sentence)

        if showAttentionProbs {
            for layerOutput in encoded.allLayerOutputs {
                if let attention = layerOutput.attentionOutput {
                    tensorShow2(attention.attentionProbs[0, 0])
                }
            }
        }

        let memory = encoded.lastLayerOutput * memoryMultiplier
        print("  memory.count: \(memory.shape)")

        print("\nGenerate:")
        print("=========")

        // Start with a neutral motion frame, optionally followed by the supplied leading motion.
        let zeroMotionFrame = LangMotionBatch.zeroMotionFrame(nbJoints: nbJoints).expandingShape(at: 0)
        var ys: Tensor<Float> = zeroMotionFrame
        if let leadingMotion = startMotion {
            ys = Tensor<Float>(concatenating: [zeroMotionFrame, leadingMotion.expandingShape(at: 0)], alongAxis: 1)
        }

        print("ys.shape: \(ys.shape)")

        // Leading frames count towards maxMotionLength. Clamp at zero: a leading motion longer
        // than maxMotionLength would otherwise build an invalid range and trap.
        let stepCount = max(0, maxMotionLength - ys.shape[1] + 1)

        for _ in 0..<stepCount {
            print(".", terminator: "")
            // prepare input
            let motionLen = ys.shape[1]
            let motionPartFlag = Tensor<Int32>(repeating: 1, shape: [1, motionLen])
            // TODO: use makeSubsequentMask, b/c it doesn't do 0 padding to the right
            // FIXME: fix target mask — the standard mask is patched by hand below.
            var motionPartMask = LangMotionBatch.makeStandardMask(target: motionPartFlag, pad: 0, shiftRight: true)
            // The flag tensor is all ones, so its sum equals the current frame count (motionLen).
            motionPartMask[0, 0..<motionLen-1, 0..<motionLen] -= 1
            motionPartMask = abs(motionPartMask)

            // FIXME: refactor getting motionStartFlag
            var motionStartFlag = Tensor<Float>(zeros: [motionLen, 1]).expandingShape(at: 0)
            motionStartFlag[0, 0, 0] = Tensor(1.0)
            let motionPart = LangMotionBatch.MotionPart(motion: ys, mask: motionPartMask, startFlag: motionStartFlag, motionFlag: motionPartFlag.expandingShape(at: 2))
            let source = LangMotionBatch.Source(sentence: sentence, motionPart: motionPart)

            // decode motion
            let decoded = transformer.decode(sourceMask: source.sourceAttentionMask, motionPart: motionPart, memory: memory)

            if showAttentionProbs {
                // All source-attention plots first, then all target-attention plots.
                for layerOutput in decoded.allLayerOutputs {
                    if let attention = layerOutput.sourceAttentionOutput {
                        tensorShow2(attention.attentionProbs[0, 0])
                    }
                }
                for layerOutput in decoded.allLayerOutputs {
                    if let attention = layerOutput.targetAttentionOutput {
                        tensorShow2(attention.attentionProbs[0, 0])
                    }
                }
            }

            // Only the last position of the decoder results is fed to the mixture model.
            let mixtureModelInput = Tensor<Float>(concatenating: decoded.allResults, alongAxis: 2)
            let lastStepInput = mixtureModelInput[0..., -1].expandingShape(at: 0)
            let singlePreds = transformer.mixtureModel(lastStepInput)

            // perform sampling
            let (sampledMotion, _, _) = MotionDecoder.performNormalMixtureSampling(
                preds: singlePreds, nb_joints: nbJoints, nb_mixtures: nbMixtures, maxMotionLength: maxMotionLength)

            // concatenate motion
            ys = Tensor(concatenating: [ys, sampledMotion.expandingShape(at: 0)], alongAxis: 1)
        }
        print()
        // Drop the batch axis and the initial neutral frame.
        return ys.squeezingShape(at: 0)[1...]
    }
}

In [None]:
/// Identifies a run of frames inside one motion sample of the dataset.
public struct SampleMotionClip {
    /// ID of the motion sample to clip from.
    var sampleID: Int
    /// Index of the first frame of the clip.
    var start: Int = 0
    /// Number of frames requested, counted from `start`.
    var length: Int = 1
}

/// Returns the frames `start ..< start + length` of the motion sample referenced by `clipInfo`.
///
/// - Parameters:
///   - dataset: Dataset whose `motionSamples` are searched by `sampleID`.
///   - clipInfo: Clip description; pass `nil` to request no clip.
/// - Returns: The sliced motion of the first sample with a matching ID, or `nil` when
///   `clipInfo` is `nil` or no sample with that ID exists (a message is printed in that case).
public func getClippedMotionFrames(dataset: Lang2Motion, clipInfo: SampleMotionClip?) -> Tensor<Float>? {
    guard let clip = clipInfo else { return nil }
    // first(where:) stops at the first match and, unlike `filter { ... }[0]`, does not trap
    // when the ID is unknown.
    guard let sample = dataset.motionSamples.first(where: { $0.sampleID == clip.sampleID }) else {
        print("getClippedMotionFrames: no motion sample with sampleID \(clip.sampleID)")
        return nil
    }
    return sample.motion[clip.start..<clip.start + clip.length]
}

In [None]:
/// Generates a motion for `sentence` with `model`, plots it, and optionally saves the plot
/// and an MMM XML file.
///
/// Relies on the notebook globals `textProcessor`, `maxTextSequenceLength` and `config`.
/// Switches the global learning phase to `.inference` and does not restore it.
///
/// - Parameters:
///   - dataset: Provides the scaler, the joint names and the leading-frame samples.
///   - model: Transformer used for decoding.
///   - sentence: Raw text prompt.
///   - leadingFrames: Optional clip whose frames seed the generated motion.
///   - prefix: File name stem for the saved `.png` and `.mmm.xml` files.
///   - saveMotion: When `true` (and `motionsURL` is set) the image and the motion are written to disk.
///   - memoryMultiplier: Factor applied to the encoder output.
///     NOTE(review): the default here is 0.0 while `MotionDecoder2.greedyDecodeMotion2`
///     defaults to 1.0; the default is left as is to preserve the interface — confirm it is intended.
///   - motionsURL: Output directory; required for saving.
///   - maxMotionLength: Target motion length, also used as the padding width of the plot.
///   - showAttentionProbs: Forwarded to the decoder to plot attention probabilities.
public func greedyDecodeMotion2(dataset: Lang2Motion, model: LangMotionTransformer, sentence: String, leadingFrames: SampleMotionClip?, prefix: String = "prefix",
                                saveMotion: Bool = true, memoryMultiplier: Float = 0.0, motionsURL: URL?, maxMotionLength: Int, showAttentionProbs: Bool = true) {
    let startMotion: Tensor<Float>? = getClippedMotionFrames(dataset: dataset, clipInfo: leadingFrames)
    let leadingFramesStr = startMotion.map { "\($0.shape[0])" } ?? "0"
    // TODO: incorporate done/stop signal
    Context.local.learningPhase = .inference
    print("\ngreedyDecodeMotion(sentence: \"\(sentence)\")")

    let processedSentence = textProcessor.preprocess(sentence: sentence, maxTextSequenceLength: maxTextSequenceLength)
    processedSentence.printSentence()

    let decodedMotion = MotionDecoder2.greedyDecodeMotion2(
        sentence: processedSentence,
        startMotion: startMotion,
        transformer: model,
        nbJoints: config.nbJoints,
        nbMixtures: config.nbMixtures,
        maxMotionLength: maxMotionLength,
        memoryMultiplier: memoryMultiplier,
        showAttentionProbs: showAttentionProbs
    )
    print("  decodedMotion: min: \(decodedMotion.min()), max: \(decodedMotion.max())")
    let descaledMotion = dataset.scaler.inverse_transform(decodedMotion)
    print("  descaledMotion.shape: \(descaledMotion.shape)")
    print("  descaledMotion: min: \(descaledMotion.min()), max: \(descaledMotion.max())")

    // Saving needs a destination directory; without one the motion is only displayed
    // instead of crashing on a force unwrap.
    let outputDirURL: URL? = saveMotion ? motionsURL : nil
    if saveMotion && motionsURL == nil {
        print("saveMotion is true but motionsURL is nil; showing the motion without saving.")
    }
    let imageURL = outputDirURL?.appendingPathComponent("\(prefix).png")

    // use joint groupping
    let jointNames = dataset.motionSamples[0].jointNames
    let grouppedJointsMotion = MotionSample.grouppedJoints(motion: descaledMotion, jointNames: jointNames)
    motionToImg(url: imageURL, motion: grouppedJointsMotion, motionFlag: nil, padTo: maxMotionLength, descr: "\(sentence), LF: \(leadingFramesStr)", cmapRange: 1.0)

    if let outputDirURL = outputDirURL, let imageURL = imageURL {
        print("Saved image: \(imageURL.path)")
        let mmmXMLDoc = MMMWriter.getMMMXMLDoc(jointNames: jointNames, motion: descaledMotion)
        let mmmURL = outputDirURL.appendingPathComponent("\(prefix).mmm.xml")
        do {
            try mmmXMLDoc.xmlData(options: XMLNode.Options.nodePrettyPrint).write(to: mmmURL)
            print("Saved motion: \(mmmURL.path)")
        } catch {
            // Report the failure instead of taking down the notebook kernel.
            print("Failed to save motion to \(mmmURL.path): \(error)")
        }
    }
}

In [None]:
/// Prints value ranges of a dataset motion sample and plots its descaled, joint-grouped motion.
///
/// Uses the notebook globals `dataset` (scaler and joint names of the first sample) and
/// `maxMotionLength` (padding width of the plot). The plot title carries the sample ID and the
/// sample's first annotation.
func showMotionSample(_ motionSample: MotionSample) {
    let motion = motionSample.motion
    let descaledMotion = dataset.scaler.inverse_transform(motion)
    let sentence = "sample_id=\(motionSample.sampleID), ann=\(motionSample.annotations[0])"

    print("motion: min: \(motion.min()), max: \(motion.max())")
    print("descaledMotion.shape: \(descaledMotion.shape)")
    print("descaledMotion: min: \(descaledMotion.min()), max: \(descaledMotion.max())")

    // use joint groupping
    let jointNames = dataset.motionSamples[0].jointNames
    let grouppedJointsMotion = MotionSample.grouppedJoints(motion: descaledMotion, jointNames: jointNames)
    motionToImg(url: nil, motion: grouppedJointsMotion, motionFlag: nil, padTo: maxMotionLength, descr: sentence, cmapRange: 1.0)
}

In [None]:
/// Plots `motion` after descaling it with the dataset scaler and grouping its joints.
///
/// Reads the notebook globals `dataset` and `maxMotionLength`; nothing is written to disk.
func showMotion(motion: Tensor<Float>) {
    let groupedMotion = MotionSample.grouppedJoints(
        motion: dataset.scaler.inverse_transform(motion),
        jointNames: dataset.motionSamples[0].jointNames
    )
    motionToImg(
        url: nil,
        motion: groupedMotion,
        motionFlag: nil,
        padTo: maxMotionLength,
        descr: "",
        cmapRange: 1.0
    )
}

In [None]:
/// Descales `motion` with the dataset scaler and writes it as pretty-printed MMM XML to `mmmURL`.
///
/// Uses the notebook global `dataset` for the scaler and for the joint names of its first
/// sample. A write failure is reported on the console instead of crashing the kernel.
func saveMotionToMMM(motion: Tensor<Float>, mmmURL: URL) {
    let descaledMotion = dataset.scaler.inverse_transform(motion)
    let jointNames = dataset.motionSamples[0].jointNames
    let mmmXMLDoc = MMMWriter.getMMMXMLDoc(jointNames: jointNames, motion: descaledMotion)
    do {
        try mmmXMLDoc.xmlData(options: XMLNode.Options.nodePrettyPrint).write(to: mmmURL)
        print("Saved motion: \(mmmURL.path)")
    } catch {
        print("Failed to save motion to \(mmmURL.path): \(error)")
    }
}

In [None]:
// Run and epoch used while browsing dataset samples; both names are declared again
// in the "Load model checkpoint" section further down.
let runName = "run_50"
let epoch = 88

In [None]:
// Directories of the selected run; the output directory for generated motions is created if missing
// (`try!` aborts the cell when it cannot be created).
let runURL = dataURL.appendingPathComponent("runs/Lang2motion/\(runName)", isDirectory: true)
let checkpointURL = runURL.appendingPathComponent("checkpoints", isDirectory: true)
let motionsURL = runURL.appendingPathComponent("generated_motions", isDirectory: true)
try! FileManager().createDirectory(at: motionsURL, withIntermediateDirectories: true)

// No model is loaded in this cell; the checkpoint is loaded in the "Load model checkpoint" section.
// let model = LangMotionTransformer(checkpoint: checkpointURL, config: config, name: "model.e\(epoch)")
// let model = LangMotionTransformer(checkpoint: checkpointURL, config: config, name: "model.final")

## Decode using leading motion frames

### Find suitable motion sample

In [None]:
// Language records of the dataset; the trailing expression shows how many there are.
let annotations = dataset.langRecs
annotations.count

In [None]:
// Peek at the first three records as dictionaries with string values.
// Slicing 0..<3 requires at least three records.
let dics = annotations[0..<3].map { ["sampleID": "\($0.sampleID)", "text": $0.text] }
dics

# save annotations

## searching

In [None]:
// Keep the records whose text contains the search term and print how many matched.
let search = "hand"
let filteredAnns = annotations.filter { $0.text.contains(search) }
print(filteredAnns.count)
// Show a page of up to 10 matches starting at `startIdx`. dropFirst/prefix clamp to the
// available elements, so paging past the end yields a shorter (or empty) page instead of
// an out-of-bounds trap.
let startIdx = 0
filteredAnns.dropFirst(startIdx).prefix(10).map { (sampleID: $0.sampleID, ann: $0.text) }

### Select motion sample

In [None]:
// Pick one of the matching records.
// NOTE(review): index 8 requires at least nine search matches; adjust it when the search term changes.
let selAnn = filteredAnns[8]
// The motion's first dimension is reported as the sample length.
let selSampleInfo = (sampleID: selAnn.sampleID, text: selAnn.text, length: selAnn.motionSample.motion.shape[0])

print("Selected motion sample")
print(selSampleInfo)
// Plot the sample and store it as "sample.mmm.xml" in the generated-motions directory.
showMotionSample(selAnn.motionSample)
saveMotionToMMM(motion: selAnn.motionSample.motion, mmmURL: motionsURL.appendingPathComponent("sample.mmm.xml"))

### Clip motion

In [None]:
// Clip of the selected sample: 10 frames requested, starting at frame index 5.
let clipInfo = SampleMotionClip(sampleID: selSampleInfo.sampleID, start: 5, length: 10)

In [None]:
// Extract the clip, report its size and value range, plot it and store it as "clip.mmm.xml".
// The force unwraps assume that the clip could be extracted.
let clippedMotionFrames: Tensor<Float>? = getClippedMotionFrames(dataset: dataset, clipInfo: clipInfo)
print("\n**** \(clipInfo) ****\n")
print("Actual length: \(clippedMotionFrames!.shape[0])")
print("clippedMotionFrames: min: \(clippedMotionFrames!.min()), max: \(clippedMotionFrames!.max())")
showMotion(motion: clippedMotionFrames!)
saveMotionToMMM(motion: clippedMotionFrames!, mmmURL: motionsURL.appendingPathComponent("clip.mmm.xml"))

## Load model checkpoint

In [None]:
// Run and epoch of the checkpoint to load; `epoch` also ends up in the generated file names.
let runName = "run_51"
let epoch = 150

In [None]:
// Directories of the selected run; the output directory for generated motions is created if missing
// (`try!` aborts the cell when it cannot be created).
let runURL = dataURL.appendingPathComponent("runs/Lang2motion/\(runName)", isDirectory: true)
let checkpointURL = runURL.appendingPathComponent("checkpoints", isDirectory: true)
let motionsURL = runURL.appendingPathComponent("generated_motions", isDirectory: true)
try! FileManager().createDirectory(at: motionsURL, withIntermediateDirectories: true)

// Restore the model from the checkpoint named after the selected epoch ("model.e<epoch>").
let model = LangMotionTransformer(checkpoint: checkpointURL, config: config, name: "model.e\(epoch)")
// let model = LangMotionTransformer(checkpoint: checkpointURL, config: config, name: "model.final")

### Generate motion

In [None]:
// Running number used in generated file names; incremented after every generation.
var genNum = 1

In [None]:
// Prompt sentence and optional leading-frames clip, assigned in the next cell.
var s: String = ""
var lf: SampleMotionClip?

In [None]:
// Active prompt for the next generation run; the commented-out lines in this cell are
// alternative sentences and leading-frame clips to try.
// s = "A person is walking forwards five steps."
s = "A person is walking forwards."
// lf = SampleMotionClip(sampleID: 1, start: 26, length: 2)
lf = nil

// s = "A person plays the guitar."
// lf = SampleMotionClip(sampleID: 1438, start: 14, length: 10)

// s = "The human plays air guitar and sways ans stands still."
// s = "The human walks in the straight line."
// s = "Someone is jogging."

// s = "a person waves with his both arms"
// s = "a person is waving his hand."
// s = "a person waves with its right hand"
// s = "a person raises his right hand"
// s = "Someone raises a hand"

// s = "A person runs."
// s = "The human is running"
// lf = SampleMotionClip(sampleID: 449, start: 14, length: 10)

// s = "A person kneels down."
// s = "A human walking backwards"
// s = "A person walks 4 steps forward."

// s = "A person performs a high kick"
// lf = SampleMotionClip(sampleID: 610, start: 5, length: 10)
// s = "A person is standing up from kneeling."

In [None]:
// Generate a motion for the current prompt and save the image and MMM file under a name
// built from the epoch and the running generation number, then advance the counter.
// The encoder memory is used unscaled (multiplier 1.0) and attention plots are disabled.
greedyDecodeMotion2(dataset: dataset, model: model, sentence: s, leadingFrames: lf, 
    prefix: "epoch_\(epoch)_motion_\(genNum)", 
    saveMotion: true, memoryMultiplier: 1.0, motionsURL: motionsURL,
    maxMotionLength: 100, showAttentionProbs: false
)
genNum += 1