# Forward pass analysis

In [None]:
// for local development
%install-location /notebooks/language2motion.gt/swift-install
%install-swiftpm-flags -c release
%install '.package(path: "/notebooks/language2motion.gt")' Datasets TranslationModels TextModels ModelSupport SummaryWriter LangMotionModels Checkpoints

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import FoundationXML
import ModelSupport
import Datasets
import SummaryWriter
import LangMotionModels
import Checkpoints
import PythonKit

In [None]:
%include "EnableIPythonDisplay.swift"
// Ask the IPython shell to render matplotlib figures inline in the notebook output.
IPythonDisplay.shell.enable_matplotlib("inline")

## Set data-loading parameters

In [None]:
// Eager TensorFlow device; passed to the dataset loader and the loss arguments below.
let device: Device = .defaultTFEager

In [None]:
// Length limits handed to the text processor (tokens) and to motion preprocessing (frames).
let (maxTextSequenceLength, maxMotionLength) = (20, 100)

In [None]:
// Dataset variant selects the .plist file name below; batchSize is passed to Lang2Motion.
let datasetSize = DatasetSize.full
let batchSize = 2

In [None]:
let dataURL = URL(fileURLWithPath: "/notebooks/language2motion.gt/data/")
// NOTE(review): the rawValue is spliced directly before "plist", so it presumably carries
// its own trailing "." (or is empty for .full) — confirm against DatasetSize.
let motionDatasetFileName = "motion_dataset_v3.10Hz.\(datasetSize.rawValue)plist"
let motionDatasetURL = dataURL.appendingPathComponent(motionDatasetFileName)

In [None]:
/// Text processing: case-insensitive BERT tokenizer over vocab.txt, "[UNK]" for unknown tokens.
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
let vocabulary = try! Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(
    vocabulary: vocabulary,
    caseSensitive: false,
    unknownToken: "[UNK]",
    maxTokenLength: nil
)
let textProcessor = TextProcessor(vocabulary: vocabulary, tokenizer: tokenizer)

/// Model hyperparameters.
/// NOTE(review): presumably these must match the run whose checkpoint is loaded below — confirm.
let config = LangMotionTransformerConfig(
    vocabSize: vocabulary.count,
    nbJoints: 47, // TODO: get value from dataset
    nbMixtures: 20,
    layerCount: 6,
    modelSize: 256,
    feedForwardSize: 1024,
    headCount: 8,
    dropoutProbability: 0.1,
    sentenceMaxPositionalLength: 100,
    motionMaxPositionalLength: 500
)

// Which run / epoch to restore.
let runName = "run_16"
let epoch = 35

// Directory layout: <data>/runs/Lang2motion/<runName>/{checkpoints, generated_motions}.
let runURL = dataURL.appendingPathComponent("runs/Lang2motion/\(runName)", isDirectory: true)
let checkpointURL = runURL.appendingPathComponent("checkpoints", isDirectory: true)
let motionsURL = runURL.appendingPathComponent("generated_motions", isDirectory: true)
try! FileManager.default.createDirectory(at: motionsURL, withIntermediateDirectories: true)

// Build the model from the checkpoint directory, using the checkpoint name "model.e<epoch>".
let model = LangMotionTransformer(checkpoint: checkpointURL, config: config, name: "model.e\(epoch)")

In [None]:
/// Load the motion dataset. Each sample is turned into a one-sample LangMotionBatch:
/// its first annotation becomes the sentence, and its motion is split into source part and target.
print("\nLoading dataset...")

var dataset = try Lang2Motion(
    motionDatasetURL: motionDatasetURL,
    batchSize: batchSize,
    trainTestSplit: 1.0,
    device: device
) { (sample: MotionSample) -> LangMotionBatch in
    let processedSentence = textProcessor.preprocess(
        sentence: sample.annotations[0],
        maxTextSequenceLength: maxTextSequenceLength
    )
    let (samplePart, sampleTarget) = LangMotionBatch.preprocessTargetMotion(
        sampleID: sample.sampleID,
        motion: sample.motion,
        maxMotionLength: maxMotionLength
    )
    return LangMotionBatch(
        data: LangMotionBatch.Source(sentence: processedSentence, motionPart: samplePart),
        label: sampleTarget
    )
}

print("Dataset acquired.")

# Forward pass

In [None]:
// + create batch
// + run forward pass
// + compute loss
// TODO: visualize data:
//       - ...

In [None]:
// Take the first motion sample of the loaded dataset for inspection.
let motionSample = dataset.motionSamples[0]
// Bare trailing expression: its value is shown as this cell's output.
motionSample.description

In [None]:
extension LangMotionBatch {
    /// Notebook-local variant of `preprocessTargetMotion` for inspecting the source/target split.
    ///
    /// The motion is passed through `paddedAndCropped(to: maxMotionLength)` and both the motion and
    /// its flag get a leading axis of size 1. The source part holds every frame except the last;
    /// the target holds every frame except the first, i.e. a one-frame shift. Target stops are
    /// `1 - flag` over the target frames.
    ///
    /// - Parameters:
    ///   - sampleID: Stored in the returned target as a one-element Int32 tensor.
    ///   - motion: Motion tensor; `shape[0]` is recorded as the original frame count.
    ///   - maxMotionLength: Length passed to `paddedAndCropped(to:)`.
    /// - Returns: The source `MotionPart` (motion plus mask) and the `Target`.
    public static func preprocessTargetMotion2(sampleID: Int, motion: Tensor<Float>, maxMotionLength: Int) -> (motionPart: MotionPart, target: Target) {
        // Frame count before padding/cropping.
        let origMotionFramesCount = Tensor<Int32>([Int32(motion.shape[0])])

        // Pad/crop, then add a leading batch axis to both the motion and its flag.
        let (unbatchedMotion, unbatchedFlag) = motion.paddedAndCropped(to: maxMotionLength)
        let paddedMotion = unbatchedMotion.expandingShape(at: 0)
        let motionFlag = unbatchedFlag.expandingShape(at: 0)

        // Source: all frames but the last, masked from the corresponding flag slice.
        let sourceFrames = 0..<(paddedMotion.shape[1] - 1)
        let motionPart = MotionPart(
            motion: paddedMotion[0..., sourceFrames, 0...],
            mask: makeStandardMask(target: motionFlag[0..., sourceFrames], pad: 0) // FIXME: fix target mask
        )

        // Target: all frames but the first, with stops = 1 - flag.
        // FIXME: should targetTruthStop encompass current motion frame?
        let targetMotion: Tensor<Float> = paddedMotion[0..., 1..., 0...]
        let targetStops: Tensor<Float> = 1.0 - Tensor<Float>(motionFlag[0..., 1...])

        let target = Target(
            sampleID: Tensor([Int32(sampleID)]),
            motion: targetMotion,
            stops: targetStops,
            origMotionFramesCount: origMotionFramesCount
        )
        return (motionPart: motionPart, target: target)
    }
}

In [None]:
// Build one batch by hand from `motionSample`, using the notebook-local preprocessTargetMotion2.
let sentence = textProcessor.preprocess(
    sentence: motionSample.annotations[0],
    maxTextSequenceLength: maxTextSequenceLength
)
let (motionPart, target) = LangMotionBatch.preprocessTargetMotion2(
    sampleID: motionSample.sampleID,
    motion: motionSample.motion,
    maxMotionLength: maxMotionLength
)
let source = LangMotionBatch.Source(sentence: sentence, motionPart: motionPart)
let singleBatch = LangMotionBatch(data: source, label: target)

In [None]:
// Dump the batch's source (built from the sentence and motion part) via printSource().
singleBatch.data.printSource()

In [None]:
// Dump the batch's target (shifted motion, stops, original frame count) via printTarget().
singleBatch.label.printTarget()

In [None]:
// Forward pass on the hand-built batch; the result is what the loss below takes as `y_pred`.
let preds = model(singleBatch.data)
preds.printPreds()

In [None]:
// Loss function settings: joint/mixture counts come from the model config; no mixture regularizer.
let args = LossArgs(
    nb_joints: config.nbJoints,
    nb_mixtures: config.nbMixtures,
    mixture_regularizer_type: "None",  // ["cv", "l2", "None"]
    mixture_regularizer: 0.0,
    device: device
)

/// Normal-mixture surrogate loss, summed and divided by `shape[0] * shape[1]` of the raw loss.
///
/// NOTE(review): this equals the mean only if `normalMixtureSurrogateLoss` returns a rank-2
/// tensor (presumably batch × time) — confirm.
@differentiable
func embeddedNormalMixtureSurrogateLoss(y_pred: MixtureModelPreds, y_true: LangMotionBatch.Target) -> Tensor<Float> {
    let rawLoss = normalMixtureSurrogateLoss(y_true: y_true, y_pred: y_pred, args: args)
    let itemCount = Float(rawLoss.shape[0] * rawLoss.shape[1])
    return rawLoss.sum() / itemCount
}

In [None]:
// Loss of the forward-pass predictions against the batch's target.
let loss = embeddedNormalMixtureSurrogateLoss(y_pred: preds, y_true: singleBatch.label)
// Bare trailing expression: its value is shown as this cell's output.
loss