# Motion generation from checkpoints

In [None]:
// for local development
// Jupyter install directives (not Swift statements): they pull in the SwiftPM package located at the
// local path below and request the products that the import cell relies on.
// NOTE(review): `-c release` asks SwiftPM for a release-configuration build - confirm this is intended for debugging sessions.
%install-location /notebooks/language2motion.gt/swift-install
%install-swiftpm-flags -c release
%install '.package(path: "/notebooks/language2motion.gt")' Datasets TranslationModels TextModels ModelSupport SummaryWriter LangMotionModels Checkpoints

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import FoundationXML
import ModelSupport
import Datasets
import SummaryWriter
import LangMotionModels
import Checkpoints

## Set training params

In [None]:
// Name of the run; it selects the `model_checkpoints/<runName>/...` directories used for checkpoints and generated motions.
let runName: String = "run_2"

In [None]:
// Length limits handed to the text processor; `maxMotionLength` is also used when decoding and rendering motions.
let maxTextSequenceLength: Int = 20
let maxMotionLength: Int = 100

In [None]:
// Dataset variant (its raw value becomes part of the motion dataset file name) and the batch size used when loading it.
let datasetSize = DatasetSize.full
let batchSize: Int = 150

In [None]:
// Root data directory; every other path in this notebook is derived from it.
let dataURL: URL = URL(fileURLWithPath: "/notebooks/language2motion.gt/data/")
// Motion samples file; the file name embeds the raw value of the selected dataset size.
let motionDatasetURL: URL = dataURL.appendingPathComponent("motion_dataset_v3.10Hz.\(datasetSize.rawValue)plist")
// CSV file with the language labels.
let langDatasetURL: URL = dataURL.appendingPathComponent("labels_ds_v2.csv")

In [None]:
// Instantiate the text processor.
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
// Plain `try` (as in the dataset-loading cell) reports a missing or unreadable vocabulary file as an
// error instead of trapping the whole process the way `try!` does.
let vocabulary: Vocabulary = try Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(vocabulary: vocabulary, caseSensitive: false, unknownToken: "[UNK]", maxTokenLength: nil)
let textProcessor = TextProcessor2(vocabulary: vocabulary, tokenizer: tokenizer, maxTextSequenceLength: maxTextSequenceLength, maxMotionLength: maxMotionLength)

// Model hyperparameters; they are collected into `LangMotionTransformerConfig` further down.
// NOTE(review): presumably these must match the values the checkpoint was trained with - confirm.
let vocabSize = vocabulary.count
let nbJoints = 47 // TODO: get value from dataset
let nbMixtures = 20
let layerCount: Int = 6
let modelSize: Int = 256
let feedForwardSize: Int = 1024
let headCount: Int = 8
let dropoutProbability: Double = 0.1

In [None]:
// Load the dataset; every example is converted to a `LangMotionBatch` by the text processor.
print("\nLoading dataset...")

var dataset = try Lang2Motion(
    motionDatasetURL: motionDatasetURL,
    langDatasetURL: langDatasetURL,
    batchSize: batchSize
) { (example: Lang2Motion.Example) -> LangMotionBatch in
    textProcessor.preprocess(example: example)
}

print("Dataset acquired.")

## Helpers

In [None]:
extension LangMotionTransformer {
    /// Restores a `LangMotionTransformer` from a saved checkpoint.
    ///
    /// The checkpoint is read from `checkpoint/<name>`. Encoder, decoder, motion dense layer, embedding and
    /// mixture model are built from the reader under the `model/...` scopes; the positional encoding is
    /// constructed from `config` alone.
    ///
    /// - Parameters:
    ///   - checkpoint: Directory that contains the checkpoint named `name`.
    ///   - config: Hyperparameters used to build the layers; they are not read from the checkpoint yet (see TODO).
    ///   - name: Checkpoint name; appended to `checkpoint` and passed to the reader as the model name.
    /// - Throws: Any error raised while opening the checkpoint; it is printed before being rethrown.
    public init(checkpoint: URL, config: LangMotionTransformerConfig, name: String) throws {
        print("Loading model \"\(name)\" from \"\(checkpoint.path)\"...")
        do {
            // Open the checkpoint together with its auxiliary hyperparameter file.
            let reader = try CheckpointReader(
                checkpointLocation: checkpoint.appendingPathComponent(name),
                modelName: name,
                additionalFiles: ["hparams.json"])

            // TODO: load config (values)

            // Restore every sub-module from its own scope below the root scope.
            let rootScope = "model"
            let restoredEncoder = Encoder(reader: reader, config: config, scope: "\(rootScope)/encoder")
            let restoredDecoder = Decoder(reader: reader, config: config, scope: "\(rootScope)/decoder")
            let restoredMotionDense = Dense<Float>(reader: reader, config: config, scope: "\(rootScope)/motionDense")
            let restoredEmbedding = Embedding<Float>(reader: reader, config: config, scope: "\(rootScope)/embedding")
            let positionalEncoding = PositionalEncoding(
                size: config.modelSize,
                dropoutProbability: config.dropoutProbability)
            let sourceEmbed = Sequential(restoredEmbedding, positionalEncoding)
            let restoredMixtureModel = MotionGaussianMixtureModel(
                reader: reader,
                config: config,
                scope: "\(rootScope)/mixtureModel")

            self.init(
                encoder: restoredEncoder,
                decoder: restoredDecoder,
                embedding: restoredEmbedding,
                positionalEncoding: positionalEncoding,
                motionDense: restoredMotionDense,
                sourceEmbed: sourceEmbed,
                mixtureModel: restoredMixtureModel,
                modelSize: config.modelSize,
                nbJoints: config.nbJoints,
                nbMixtures: config.nbMixtures)
        } catch {
            // An invalid checkpoint is reported and the error is handed on to the caller.
            print("Fail to load LangMotionTransformer from checkpoint. \(error)")
            throw error
        }
    }
}

In [None]:
/// Decodes a motion for `sentence` with greedy decoding, renders it as an image and optionally saves the results.
///
/// Uses the notebook-level `textProcessor`, `dataset`, `nbJoints`, `nbMixtures`, `maxMotionLength`, `dataURL`
/// and `runName`. When `saveMotion` is true, a PNG and an MMM XML file named after `prefix` are written to
/// `model_checkpoints/<runName>/generated_motions/` below `dataURL`; the directory is created when missing.
/// Failures while saving are printed instead of trapping the process.
///
/// - Parameters:
///   - model: Transformer used for decoding.
///   - sentence: Text describing the motion to generate.
///   - prefix: Base name of the output files; also part of the image description.
///   - saveMotion: Whether the image and the MMM XML file are written to disk.
public func greedyDecodeMotion(model: LangMotionTransformer, sentence: String, prefix: String = "prefix", saveMotion: Bool = true) {
    // TODO: incorporate done/stop signal
    Context.local.learningPhase = .inference
    print("\ngreedyDecodeMotion(sentence: \"\(sentence)\")")

    let source = textProcessor.preprocess(sentence: sentence)
    source.printSource()

    let decodedMotion = MotionDecoder.greedyDecodeMotion(source: source, transformer: model, nbJoints: nbJoints, nbMixtures: nbMixtures, maxMotionLength: maxMotionLength)
    print("  decodedMotion: min: \(decodedMotion.min()), max: \(decodedMotion.max())")
    let descaledMotion = dataset.scaler.inverse_transform(decodedMotion)
    print("  descaledMotion.shape: \(descaledMotion.shape)")
    print("  descaledMotion: min: \(descaledMotion.min()), max: \(descaledMotion.max())")

    // Resolve the output directory only when saving; `nil` means that nothing is written.
    var outputDirURL: URL? = nil
    if saveMotion {
        let candidateURL = dataURL.appendingPathComponent("model_checkpoints/\(runName)/generated_motions", isDirectory: true)
        do {
            // Writing into a missing directory would fail, so make sure it exists first.
            try FileManager.default.createDirectory(at: candidateURL, withIntermediateDirectories: true, attributes: nil)
            outputDirURL = candidateURL
        } catch {
            print("Failed to create output directory \(candidateURL.path): \(error)")
        }
    }

    let imageURL = outputDirURL?.appendingPathComponent("\(prefix).png")
    motionToImg(url: imageURL, motion: descaledMotion, motionFlag: nil, padTo: maxMotionLength, descr: "\(prefix), \(sentence)", cmapRange: 2.0)

    guard let saveDirURL = outputDirURL, let savedImageURL = imageURL else { return }
    print("Saved image: \(savedImageURL.path)")

    // Joint names are taken from the first training example; an empty dataset cannot provide them.
    guard let jointNames = dataset.trainExamples.first?.motionSample.jointNames else {
        print("Cannot save motion: dataset has no training examples to take joint names from.")
        return
    }
    let mmmXMLDoc = MMMWriter.getMMMXMLDoc(jointNames: jointNames, motion: descaledMotion)
    let mmmURL = saveDirURL.appendingPathComponent("\(prefix).mmm.xml")
    do {
        try mmmXMLDoc.xmlData(options: XMLNode.Options.nodePrettyPrint).write(to: mmmURL)
        print("Saved motion: \(mmmURL.path)")
    } catch {
        print("Failed to save motion to \(mmmURL.path): \(error)")
    }
}

## Load model checkpoint

In [None]:
// Transformer configuration assembled from the hyperparameter constants defined earlier in the notebook.
let config = LangMotionTransformerConfig(
    vocabSize: vocabSize, nbJoints: nbJoints, nbMixtures: nbMixtures,
    layerCount: layerCount, modelSize: modelSize, feedForwardSize: feedForwardSize,
    headCount: headCount, dropoutProbability: dropoutProbability
)

In [None]:
// Directory holding the checkpoints of the selected run.
let logdirURL: URL = dataURL.appendingPathComponent("model_checkpoints/\(runName)/checkpoints", isDirectory: true)

In [None]:
// Epoch number; it is interpolated into the checkpoint name and the output file prefix.
let epoch: Int = 17

In [None]:
// `init(checkpoint:config:name:)` is declared `throws`, so the call has to be marked with `try`.
let model = try LangMotionTransformer(checkpoint: logdirURL, config: config, name: "model.e\(epoch)")

## Decode motion

In [None]:
// Generate a motion for a sample sentence and store the image and MMM file under the epoch-specific prefix.
greedyDecodeMotion(
    model: model,
    sentence: "human walks and then runs and later sits down",
    prefix: "epoch_\(epoch)",
    saveMotion: true
)