# Save Lang2motion transformer model

In [None]:
// for local development
%install-location /notebooks/language2motion.gt/swift-install
%install-swiftpm-flags -c release
%install '.package(path: "/notebooks/language2motion.gt")' Datasets TranslationModels TextModels ModelSupport SummaryWriter LangMotionModels Checkpoints

In [None]:
import TensorFlow
import TextModels
import TranslationModels
import Foundation
import ModelSupport
import Datasets
import SummaryWriter
import LangMotionModels
import Checkpoints

## Set training params

In [None]:
// Length limits handed to `TextProcessor2` further down in the notebook.
let maxTextSequenceLength: Int = 20
let maxMotionLength: Int = 100

In [None]:
// Root data directory; the vocabulary file and the checkpoint folder are resolved relative to it.
let dataURL = URL(fileURLWithPath: "/notebooks/language2motion.gt/data/")

## Instantiate model

In [None]:
// --- Text processor ---
// The vocabulary is read from `vocab.txt` under `dataURL`; `try!` traps if it cannot be loaded.
let vocabularyURL = dataURL.appendingPathComponent("vocab.txt")
let vocabulary: Vocabulary = try! Vocabulary(fromFile: vocabularyURL)
let tokenizer: Tokenizer = BERTTokenizer(
    vocabulary: vocabulary,
    caseSensitive: false,
    unknownToken: "[UNK]",
    maxTokenLength: nil
)
let textProcessor = TextProcessor2(
    vocabulary: vocabulary,
    tokenizer: tokenizer,
    maxTextSequenceLength: maxTextSequenceLength,
    maxMotionLength: maxMotionLength
)

// --- Model hyperparameters ---
let vocabSize = vocabulary.count
let nbJoints = 47  // TODO: get value from dataset
let nbMixtures = 20
let layerCount = 6
let modelSize = 256
let feedForwardSize = 1024
let headCount = 8
let dropoutProbability = 0.1

// --- Model ---
var model = LangMotionTransformer(
    vocabSize: vocabSize,
    nbJoints: nbJoints,
    nbMixtures: nbMixtures,
    layerCount: layerCount,
    modelSize: modelSize,
    feedForwardSize: feedForwardSize,
    headCount: headCount,
    dropoutProbability: dropoutProbability
)

## Play with writer and reader

In [None]:
// Directory under `dataURL` that is passed as the checkpoint location when writing and reading below.
let temporaryDirectory = dataURL.appendingPathComponent("CheckpointsTests", isDirectory: true)

## Save parts of the model

In [None]:
/// A value that declares which of its reflected children take part in tensor export.
///
/// `recursivelyObtainTensors` looks up each mirrored child label in `nameMappings`; only labels
/// present in the dictionary contribute a path component.
public protocol ExportableLayer {
    /// Maps a child label (as reported by `Mirror`) to the name used for it in the tensor path.
    var nameMappings: [String: String] { get }
}

In [None]:
// Notebook cell expression: evaluates the shape of the model's embedding tensor.
model.embedding.embeddings.shape

In [None]:
extension LangMotionTransformer: ExportableLayer {
    /// Exported children; each keeps its own property name as the path component.
    ///
    /// Deliberately not listed: `modelSize`, `nbJoints` and `nbMixtures` (all `Int`).
    public var nameMappings: [String: String] {
        let exported = [
            "encoder",
            "decoder",
            "embedding",
            "positionalEncoding",
            "motionDense",
            "mixtureModel",
        ]
        return Dictionary(uniqueKeysWithValues: exported.map { ($0, $0) })
    }
}

In [None]:
extension Encoder: ExportableLayer {
    /// Exports `layers` and `norm` under their own property names.
    public var nameMappings: [String: String] {
        var mappings: [String: String] = [:]
        for name in ["layers", "norm"] {
            mappings[name] = name
        }
        return mappings
    }
}

In [None]:
extension Decoder: ExportableLayer {
    /// Exports `layers` and `norm` under their own property names.
    public var nameMappings: [String: String] {
        var mappings: [String: String] = [:]
        for name in ["layers", "norm"] {
            mappings[name] = name
        }
        return mappings
    }
}

In [None]:
extension Array: ExportableLayer {
    /// Maps the fallback label `"h"` (used for unlabeled children) to `"<element type name>_h"`.
    public var nameMappings: [String: String] {
        // The printed array type is split on angle brackets, e.g. "Array<Foo>" -> ["Array", "Foo", ""].
        // Piece 1 is the element type name; a generic element keeps only the part before its own "<".
        let typeDescription = "\(type(of:self))"
        let pieces = typeDescription.components(separatedBy: ["<", ">"])
        let elementName = pieces[1]
        return ["h": elementName + "_h"]
    }
}

In [None]:
extension MotionGaussianMixtureModel: ExportableLayer {
    /// Exports the four linear layers under their own property names.
    ///
    /// Not listed: `inputSize`, `nbJoints`, `nbMixtures` and `outputSize` (all `Int`).
    public var nameMappings: [String: String] {
        let exported = [
            "linearMixtureMeans",
            "linearMixtureVars",
            "linearMixtureWeights",
            "linearStop",
        ]
        return Dictionary(uniqueKeysWithValues: exported.map { ($0, $0) })
    }
}

In [None]:
extension Embedding: ExportableLayer {
    /// Exports the single `embeddings` child under its own name.
    public var nameMappings: [String: String] {
        let key = "embeddings"
        return [key: key]
    }
}

In [None]:
extension LayerNorm: ExportableLayer {
    /// Exports `offset` and `scale` under their own property names.
    public var nameMappings: [String: String] {
        var mappings: [String: String] = [:]
        for name in ["offset", "scale"] {
            mappings[name] = name
        }
        return mappings
    }
}

In [None]:
extension Dense: ExportableLayer {
    /// Exports `weight` and `bias` under their own property names.
    public var nameMappings: [String: String] {
        var mappings: [String: String] = [:]
        for name in ["weight", "bias"] {
            mappings[name] = name
        }
        return mappings
    }
}

In [None]:
extension MultiHeadAttention: ExportableLayer {
    /// Exports the weight and bias children for query, key and value under their own names.
    ///
    /// Not listed: `sourceSize`, `targetSize`, `headCount`, `headSize`, the query/key/value
    /// activations, `matrixResult` and `attentionDropout`.
    public var nameMappings: [String: String] {
        let exported = [
            "queryWeight",
            "queryBias",
            "keyWeight",
            "keyBias",
            "valueWeight",
            "valueBias",
        ]
        return Dictionary(uniqueKeysWithValues: exported.map { ($0, $0) })
    }
}

In [None]:
extension TransformerEncoderLayer2: ExportableLayer {
    /// Exports `selfAttention`, `feedForward` and `sublayers` under their own property names.
    public var nameMappings: [String: String] {
        let exported = ["selfAttention", "feedForward", "sublayers"]
        return Dictionary(uniqueKeysWithValues: exported.map { ($0, $0) })
    }
}

In [None]:
extension PositionwiseFeedForward: ExportableLayer {
    /// Exports `dense1`, `dense2` and `dropout` under their own property names.
    public var nameMappings: [String: String] {
        let exported = ["dense1", "dense2", "dropout"]
        return Dictionary(uniqueKeysWithValues: exported.map { ($0, $0) })
    }
}

In [None]:
extension SublayerConnection: ExportableLayer {
    /// Exports `norm` and `dropout` under their own property names.
    public var nameMappings: [String: String] {
        var mappings: [String: String] = [:]
        for name in ["norm", "dropout"] {
            mappings[name] = name
        }
        return mappings
    }
}

In [None]:
extension TransformerDecoderLayer: ExportableLayer {
    /// Exports the two attention blocks, the feed-forward block and `sublayers` under their own names.
    public var nameMappings: [String: String] {
        let exported = [
            "selfAttention",
            "sourceAttention",
            "feedForward",
            "sublayers",
        ]
        return Dictionary(uniqueKeysWithValues: exported.map { ($0, $0) })
    }
}

In [None]:
/// Collects the `Tensor<Float>` values reachable from `obj` into `tensors`, keyed by a path string.
///
/// `obj` is walked with `Mirror`. For every child whose label appears in the object's
/// `ExportableLayer.nameMappings`, a path is built as `scope + separator + mappedName`, plus a
/// running index when the children are unlabeled (the label then falls back to `"h"`). If that child
/// is a `Tensor<Float>` it is stored under the path. Every child, mapped or not, is then visited
/// recursively.
///
/// Objects that are not `ExportableLayer` and not one of the known leaf types are printed together
/// with the current scope, unless the printed line contains "Tensor".
///
/// - Parameters:
///   - obj: The value to inspect.
///   - scope: Path prefix for tensors found below `obj`; `nil` means no prefix.
///   - tensors: Dictionary that receives the tensors found.
///   - separator: String placed between path components.
public func recursivelyObtainTensors(
    _ obj: Any, scope: String? = nil, tensors: inout [String: Tensor<Float>], separator: String
) {
    let m = Mirror(reflecting: obj)
    let nameMappings: [String: String]
    if let exportableLayer = obj as? ExportableLayer {
        nameMappings = exportableLayer.nameMappings
    } else {
        // Types that are expected to carry no mappings; anything else is reported below.
        let isKnownLeaf =
            (obj is Int) || (obj is Bool) || (obj is Tensor<Float>) ||
            (obj is Double) || (obj is Float) || (obj is Dropout<Float>) ||
            (obj is Parameter<Float>) || (obj is PositionalEncoding)
        if !isKnownLeaf {
            // `scope` defaults to nil, so it must not be force-unwrapped here.
            let s = "\(scope ?? "<root>") -> \(type(of:obj))"
            if !s.contains("Tensor") {
                print(s)
            }
        }
        nameMappings = [:]
    }

    // Running index per mapped label, used only for unlabeled children.
    var repeatedLabels: [String: Int] = [:]
    func suffix(for label: String) -> String {
        if let currentSuffix = repeatedLabels[label] {
            repeatedLabels[label] = currentSuffix + 1
            return "\(currentSuffix + 1)"
        } else {
            repeatedLabels[label] = 0
            return "0"
        }
    }

    // Children are treated as indexed when the first one has no label.
    let hasSuffix = (m.children.first?.label == nil)

    // NOTE(review): `path` is not reset per child, so an unmapped child is visited with the path of
    // the most recent mapped sibling (or `scope` if none came before). Confirm this is intended.
    var path = scope
    for child in m.children {
        let label = child.label ?? "h"

        if let remappedLabel = nameMappings[label] {
            let labelSuffix = hasSuffix ? suffix(for: remappedLabel) : ""
            let conditionalSeparator = remappedLabel == "" ? "" : separator
            let prefix = scope.map { $0 + conditionalSeparator } ?? ""
            let childPath = prefix + remappedLabel + labelSuffix

            path = childPath
            if let tensor = child.value as? Tensor<Float> {
                tensors[childPath] = tensor
            }
        }
        recursivelyObtainTensors(child.value, scope: path, tensors: &tensors, separator: separator)
    }
}

In [None]:
/// Writes the tensors of the notebook-level `model` as a checkpoint.
///
/// Tensors are gathered with `recursivelyObtainTensors` under the root scope `"model"` using `/` as
/// the path separator. The collected paths are printed in sorted order before writing.
///
/// - Parameters:
///   - location: Passed to `CheckpointWriter.write(to:name:)` as the destination.
///   - name: Passed to `CheckpointWriter.write(to:name:)` as the checkpoint name.
/// - Throws: Any error thrown by `CheckpointWriter.write(to:name:)`.
func writeCheckpoint(to location: URL, name: String) throws {
    var tensors = [String: Tensor<Float>]()
    recursivelyObtainTensors(model, scope: "model", tensors: &tensors, separator: "/")

    // Only the printing side effect is wanted, so use `forEach` rather than `map`.
    tensors.keys.sorted().forEach { print($0) }

    let writer = CheckpointWriter(tensors: tensors)
    try writer.write(to: location, name: name)
}

In [None]:
// `writeCheckpoint` is a throwing function, so the call must be marked with `try`.
try writeCheckpoint(to: temporaryDirectory, name: "model1")

## Reader

In [None]:
/// Hyperparameters passed to the checkpoint-based initializers in this notebook.
///
/// The field names mirror the arguments of `LangMotionTransformer(vocabSize:nbJoints:...)`.
/// `Codable` conformance is sketched in the commented-out code but not enabled.
public struct LangMotionTransformerConfig { //: Codable {
    public let vocabSize: Int
    public let nbJoints: Int
    public let nbMixtures: Int
    public let layerCount: Int
    public let modelSize: Int
    public let feedForwardSize: Int
    public let headCount: Int
    public let dropoutProbability: Double

//     enum CodingKeys: String, CodingKey {
//         case vocabSize = "vocabSize"
//     }
}

In [None]:
/// A type that can be constructed from tensors stored in a checkpoint.
protocol InitializableFromPythonCheckpoint {
    /// Creates an instance from a checkpoint.
    /// - Parameters:
    ///   - reader: Reader used to fetch tensors by name.
    ///   - config: Hyperparameters available to the conforming type.
    ///   - scope: Path prefix; conforming types in this notebook read tensors named `scope + "/<name>"`.
    init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String)
}

In [None]:
extension Dense: InitializableFromPythonCheckpoint {
    /// Builds a dense layer from the `<scope>/weight` and `<scope>/bias` tensors of `reader`.
    ///
    /// `config` is not consulted, and the activation is always `identity`.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        let weightName = "\(scope)/weight"
        let biasName = "\(scope)/bias"
        self.init(
            weight: reader.readTensor(name: weightName),
            bias: reader.readTensor(name: biasName),
            activation: identity
        )
    }
}

In [None]:
extension Embedding: InitializableFromPythonCheckpoint {
    /// Builds an embedding from the `<scope>/embeddings` tensor of `reader`; `config` is not consulted.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        let tensorName = "\(scope)/embeddings"
        self.init(embeddings: reader.readTensor(name: tensorName))
    }
}

In [None]:
extension MotionGaussianMixtureModel: InitializableFromPythonCheckpoint {
    /// Builds the mixture model from `config` sizes and four dense layers read below `scope`.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        // Reads one dense sub-layer stored under `<scope>/<name>`.
        func dense(_ name: String) -> Dense<Float> {
            Dense<Float>(reader: reader, config: config, scope: "\(scope)/\(name)")
        }

        // The layers are read in the same order as they are passed to the initializer.
        let means = dense("linearMixtureMeans")
        let vars = dense("linearMixtureVars")
        let weights = dense("linearMixtureWeights")
        let stop = dense("linearStop")

        self.init(
            inputSize: config.modelSize,
            nbJoints: config.nbJoints,
            nbMixtures: config.nbMixtures,
            linearMixtureMeans: means,
            linearMixtureVars: vars,
            linearMixtureWeights: weights,
            linearStop: stop
        )
    }
}

In [None]:
extension LayerNorm: InitializableFromPythonCheckpoint {
    /// Builds a layer norm from the `<scope>/offset` and `<scope>/scale` tensors of `reader`.
    ///
    /// `axis` and `epsilon` are fixed values here rather than restored ones.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        let offsetName = "\(scope)/offset"
        let scaleName = "\(scope)/scale"
        // FIXME: axis & epsilon defaults
        self.init(
            offset: reader.readTensor(name: offsetName),
            scale: reader.readTensor(name: scaleName),
            axis: 2,
            epsilon: 0.001
        )
    }
}

In [None]:
extension MultiHeadAttention: InitializableFromPythonCheckpoint {
    /// Builds an attention block whose six weight/bias tensors are read from below `scope`.
    ///
    /// Sizes come from `config`; activations are `identity`, attention dropout is 0 and
    /// `matrixResult` is `false`.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        let size = config.modelSize
        let heads = config.headCount
        let perHeadSize = size / heads

        self.init(
            sourceSize: size,
            targetSize: size,
            headCount: heads,
            headSize: perHeadSize,
            queryActivation: identity,
            keyActivation: identity,
            valueActivation: identity,
            attentionDropoutProbability: 0,
            matrixResult: false,
            queryWeight: reader.readTensor(name: "\(scope)/queryWeight"),
            queryBias: reader.readTensor(name: "\(scope)/queryBias"),
            keyWeight: reader.readTensor(name: "\(scope)/keyWeight"),
            keyBias: reader.readTensor(name: "\(scope)/keyBias"),
            valueWeight: reader.readTensor(name: "\(scope)/valueWeight"),
            valueBias: reader.readTensor(name: "\(scope)/valueBias")
        )
    }
}

In [None]:
extension PositionwiseFeedForward: InitializableFromPythonCheckpoint {
    /// Builds the feed-forward block from `<scope>/dense1` and `<scope>/dense2`.
    ///
    /// The dropout layer is created from `config.dropoutProbability` rather than read.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        let first = Dense<Float>(reader: reader, config: config, scope: "\(scope)/dense1")
        let second = Dense<Float>(reader: reader, config: config, scope: "\(scope)/dense2")
        let dropoutLayer = Dropout<Float>(probability: config.dropoutProbability)
        self.init(dense1: first, dense2: second, dropout: dropoutLayer)
    }
}

In [None]:
extension TransformerEncoderLayer2: InitializableFromPythonCheckpoint {
    /// Builds an encoder layer from `<scope>/selfAttention` and `<scope>/feedForward`.
    ///
    /// The `sublayers` are not restored; they come from the `size`/`dropoutProb` initializer.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        // TODO: serialize/deserialize sublayers [SublayerConnection]
        self.init(
            size: config.modelSize,
            selfAttention: MultiHeadAttention(
                reader: reader, config: config, scope: "\(scope)/selfAttention"),
            feedForward: PositionwiseFeedForward(
                reader: reader, config: config, scope: "\(scope)/feedForward"),
            dropoutProb: config.dropoutProbability
        )
    }
}

In [None]:
extension Encoder: InitializableFromPythonCheckpoint {
    /// Builds the encoder from `config.layerCount` layers plus the `<scope>/norm` layer norm.
    ///
    /// Layer `i` is read from `<scope>/layers/TransformerEncoderLayer2_h<i>`.
    public init(reader: CheckpointReader, config: LangMotionTransformerConfig, scope: String) {
        let layerScopes = (0..<config.layerCount).map { index in
            "\(scope)/layers/TransformerEncoderLayer2_h\(index)"
        }
        let restoredLayers = layerScopes.map { layerScope in
            TransformerEncoderLayer2(reader: reader, config: config, scope: layerScope)
        }
        let finalNorm = LayerNorm<Float>(reader: reader, config: config, scope: "\(scope)/norm")
        self.init(layers: restoredLayers, norm: finalNorm)
    }
}

In [None]:
extension LangMotionTransformer {
    /// Restores a model from the checkpoint whose base path is `<checkpoint>/model1`.
    ///
    /// The encoder, `motionDense`, the embedding and the mixture model are read from the checkpoint
    /// under the `model/...` scopes. The decoder and the positional encoding are not restored yet;
    /// they are constructed from the hyperparameters. The hyperparameters are currently hard-coded
    /// in a single `LangMotionTransformerConfig`.
    ///
    /// - Parameter checkpoint: Directory that contains the `model1` checkpoint.
    /// - Throws: Any error raised while creating the `CheckpointReader`; it is printed, then rethrown.
    public init(checkpoint: URL) throws {
        // Try loading from the given checkpoint.
        do {
            // TODO: load config (values)
            // All sizes below are taken from this one value so they cannot drift apart.
            // NOTE(review): vocabSize is not read anywhere in this initializer; the embedding
            // tensor comes from the checkpoint.
            let config = LangMotionTransformerConfig(
                vocabSize: 100,
                nbJoints: 47,
                nbMixtures: 20,
                layerCount: 6,
                modelSize: 256,
                feedForwardSize: 1024,
                headCount: 8,
                dropoutProbability: 0.1
            )

            // create reader
            let auxiliary: [String] = [
                "hparams.json"
            ]

            let reader: CheckpointReader = try CheckpointReader(
                checkpointLocation: checkpoint.appendingPathComponent("model1"),
                modelName: "model1",
                additionalFiles: auxiliary)
            let scope = "model"

            print(reader)

            // create objects
            let _encoder = Encoder(reader: reader, config: config, scope: scope + "/encoder")

            // TODO: serialize Decoder
            // TODO: deserialize Decoder
            let _attention = MultiHeadAttention(sourceSize: config.modelSize,
                                                targetSize: config.modelSize,
                                                headCount: config.headCount,
                                                headSize: config.modelSize/config.headCount,
                                                matrixResult: false)
            let _feedForward = PositionwiseFeedForward(dimensionalityModel: config.modelSize,
                                                       innerLayerDimensionality: config.feedForwardSize,
                                                       dropProbability: config.dropoutProbability)
            let _decoder = Decoder(
                layer: .init(
                    size: config.modelSize,
                    selfAttention: _attention,
                    sourceAttention: _attention,
                    feedForward: _feedForward,
                    dropoutProb: config.dropoutProbability),
                layerCount: config.layerCount)

            let _motionDense = Dense<Float>(reader: reader, config: config, scope: scope + "/motionDense")

            let _embedding = Embedding<Float>(reader: reader, config: config, scope: scope + "/embedding")
            // TODO: serialize PositionalEncoding
            // TODO: deserialize PositionalEncoding
            let _positionalEncoding = PositionalEncoding(
                size: config.modelSize, dropoutProbability: config.dropoutProbability)
            let _sourceEmbed = Sequential(_embedding, _positionalEncoding)

            let _mixtureModel = MotionGaussianMixtureModel(
                reader: reader, config: config, scope: scope + "/mixtureModel")

            self.init(encoder: _encoder, decoder: _decoder, embedding: _embedding, positionalEncoding: _positionalEncoding,
                      motionDense: _motionDense, sourceEmbed: _sourceEmbed, mixtureModel: _mixtureModel,
                      modelSize: config.modelSize, nbJoints: config.nbJoints, nbMixtures: config.nbMixtures)
        } catch {
            // If checkpoint is invalid, throw the error and exit.
            print("Fail to load LangMotionTransformer from checkpoint. \(error)")
            throw error
        }
    }
}

In [None]:
// `LangMotionTransformer.init(checkpoint:)` throws, so the call must be marked with `try`.
let readModel = try LangMotionTransformer(checkpoint: temporaryDirectory)