Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

S4TF speed #630

Closed
awav opened this issue Jan 15, 2020 · 9 comments
Closed

S4TF speed #630

awav opened this issue Jan 15, 2020 · 9 comments

Comments

@awav
Copy link
Contributor

awav commented Jan 15, 2020

Hello everyone,

I'm benchmarking Kalman Filter on S4TF, and here are results:

| Language     | Time  |
|--------------|-------|
| Matlab       | ~0.1s |
| Python (TF2) | ~2.5s |
| Swift        | ~10s  |

The question is: why is it so slow? Am I doing something wrong?

Swift code (Release, Development toolchain 2019-12-23):

import TensorFlow
import QuartzCore

/// Matrix exponential `exp(matrix)` of a small square matrix, built only from
/// differentiable primitives (`matmul` and elementwise arithmetic).
///
/// The S4TF toolchain this benchmark targets exposes no matrix-exponential op,
/// so this uses scaling and squaring: a truncated Taylor series is evaluated on
/// `matrix / 2^squarings`, and the result is squared `squarings` times.
///
/// - Parameters:
///   - matrix: Square matrix to exponentiate.
///   - identity: Identity matrix with the same shape as `matrix`.
///   - squarings: Number of halvings/squarings. The series is accurate while
///     the norm of `matrix / 2^squarings` is at most about 1; raise this for
///     inputs with larger entries.
///   - terms: Number of Taylor terms summed after the identity.
/// - Returns: An approximation of `exp(matrix)`.
@differentiable(wrt: (matrix, identity))
private func matrixExponential(
    _ matrix: Tensor<Float>,
    identity: Tensor<Float>,
    squarings: Int = 6,
    terms: Int = 12
) -> Tensor<Float> {
    let scaled = matrix / Float(1 << squarings)
    var result = identity
    var term = identity
    for i in 0..<terms {
        // term becomes scaled^(i+1) / (i+1)!
        term = matmul(term, scaled) / Float(i + 1)
        result = result + term
    }
    // exp(M) = exp(M / 2^s)^(2^s)
    for _ in 0..<squarings {
        result = matmul(result, result)
    }
    return result
}

/// Log marginal likelihood of `observed`, accumulated by a Kalman filter whose
/// 3-dimensional state-space model is built from `prior`, `length` and
/// `likelihood`.
///
/// - Parameters:
///   - observed: Observations, one per filter step, indexed along axis 0.
///   - likelihood: Observation-noise variance added to `H P Hᵀ`.
///   - prior: Scale of the stationary covariance `Pinf`.
///   - length: Length-scale hyperparameter.
/// - Returns: The sum over steps of `-0.5 * (v² / S + log(2π S))`, where `v` is
///   the innovation and `S` its variance.
@differentiable
func kalman(
    observed: Tensor<Float>,
    likelihood: Tensor<Float>,
    prior: Tensor<Float>,
    length: Tensor<Float>
) -> Tensor<Float> {
    typealias T = Tensor<Float>
    let lam = sqrt(5.0) / length
    let lam1 = -3 * lam
    // -3 * lam^2 = -15 / length^2. The square on `length` is required so that
    // F below is the companion matrix of (s + lam)^3.
    let lam2 = -(3 * 5) / length.squared()
    let lam3 = -pow(lam, 3)
    let z = T(0)
    let o = T(1)
    let F = T([T([z, o, z]), T([z, z, o]), T([lam3, lam2, lam1])])
    let H = T([T([o, z, z])])
    let identity = T([T([o, z, z]), T([z, o, z]), T([z, z, o])])
    let kappa = (5.0 / 3.0) * prior / length.squared()
    let Pinf = T([T([prior,  z,     -kappa]),
                  T([z,      kappa, z]),
                  T([-kappa, z,     prior / pow(length, 4) * 25])])

    // One-step transition matrix exp(F). With length == 1 the infinity norm of
    // F is about 33, so the default 2^6 scaling brings it to about 0.5.
    let A = matrixExponential(F, identity: identity)

    var m = T(zeros: [F.shape[0], 1])
    var P = Pinf
    var lml = z

    for k in 0..<observed.shape[0] {
        m = matmul(A, m)
        P = Pinf + matmul(A, matmul(P - Pinf, transposed: false, A, transposed: true))
        // P Hᵀ feeds both the innovation variance and the gain; compute it once.
        let PHt = matmul(P, transposed: false, H, transposed: true)
        let S = likelihood + matmul(H, PHt)
        let v = observed[k] - matmul(H, m)
        let lz = -0.5 * (v.squared() / S + log(2 * Float.pi * S))
        let K = PHt / S
        m = m + matmul(K, v)
        P = P - matmul(K, matmul(S, transposed: false, K, transposed: true))
        lml = lml + lz[0, 0]
    }

    return lml
}

/// Runs `kalman` once on `steps` standard-normal observations with fixed
/// hyperparameters (likelihood 0.1, prior 1.0, length 1.0); used as the
/// benchmark workload.
///
/// - Parameter steps: Number of observations, i.e. filter steps. Defaults to
///   50000, the size used by the other benchmark implementations.
func testKalmanFilter(steps: Int = 50000) {
    let y = Tensor<Float>(randomNormal: [steps])
    let likelihood = Tensor<Float>(0.1)
    let prior = Tensor<Float>(1.0)
    let length = Tensor<Float>(1.0)

    // Only the elapsed time matters to the benchmark; discard the result
    // explicitly instead of triggering an unused-result warning.
    _ = kalman(observed: y, likelihood: likelihood, prior: prior, length: length)
}

/// Runs `block` once and returns the wall-clock seconds it took, as measured
/// by `CACurrentMediaTime()`. Prints a marker line before and after the run.
func executionTimeInterval(block: () -> ()) -> CFTimeInterval {
    print("Start execution")
    let startTime = CACurrentMediaTime()
    block()
    let elapsed = CACurrentMediaTime() - startTime
    print("Finished execution")
    return elapsed
}

// Time one run of the Kalman filter benchmark and report it in seconds.
let elapsed = executionTimeInterval(block: { testKalmanFilter() })
let summary = "Elapsed time: \(elapsed)"
print(summary)

Python (TensorFlow 2.1, installed via pip, no GPU, Python 3.6):

import tensorflow as tf
import numpy as np
from time import time
from datetime import datetime
pi = 3.141592653589793
flttype = tf.float32

# Benchmark size and reproducible synthetic observations.
num_steps = 50000

np.random.seed(123)
y = tf.constant(np.random.randn(num_steps), dtype=flttype)
# y = tf.dtypes.cast(tf.linspace(0.0, 1.0, num_steps), dtype=flttype)

# Hyperparameters; they are passed to kalman_filter as
# (var_likelihood, var_prior, len_prior) = (var_y, var_f, len_f).
var_y = tf.Variable(1e-1)
len_f = tf.Variable(1.0)
var_f = tf.Variable(1.0)


# Signature shared by both tf.functions: a variable-length vector of
# observations followed by three scalar hyperparameters. Leaving the length
# unspecified lets one traced graph serve inputs of any length.
NoneTensorShape = tf.TensorShape((None, ))
input_signature = [tf.TensorSpec(NoneTensorShape, dtype=flttype)] + [
    tf.TensorSpec((), dtype=flttype) for _ in range(3)
]


@tf.function(input_signature=input_signature)
def kalman_filter(y_observed, var_likelihood, var_prior, len_prior):
    """Run a Kalman filter over ``y_observed`` and return its log marginal likelihood.

    Args:
        y_observed: 1-D tensor of observations (shape ``(None,)`` per ``input_signature``).
        var_likelihood: scalar observation-noise variance.
        var_prior: scalar prior variance; scales the stationary covariance ``Pinf``.
        len_prior: scalar lengthscale.

    Returns:
        Scalar tensor: the sum over steps of ``-0.5 * (v**2 / S + log(2 * pi * S))``,
        where ``v`` is the innovation and ``S`` its variance.
    """
    lam = 5.0**0.5 / len_prior
    f_rows = [[[0.0, 1.0, 0.0]], [[0.0, 0.0, 1.0]], [[-lam**3, -3 * lam**2, -3 * lam]]]
    F = tf.concat(f_rows, axis=0)
    H = tf.constant([[1.0, 0.0, 0.0]], dtype=flttype)
    kappa = 5.0 / 3.0 * var_prior / len_prior**2
    pinf_rows = [
        [[var_prior, 0.0, -kappa]],
        [[0.0, kappa, 0.0]],
        [[-kappa, 0.0, 25.0 * var_prior / len_prior**4.0]],
    ]
    Pinf = tf.concat(pinf_rows, axis=0)
    m = tf.zeros([F.shape[0], 1], dtype=flttype)
    P = Pinf
    N = tf.shape(y_observed)[0]
    A = tf.linalg.expm(F)  # one-step transition matrix
    log_marg_lik = tf.constant(0.0)
    with tf.name_scope("kalman_loop"):
        # AutoGraph turns this loop over a tf.range into a graph while-loop.
        for k in tf.range(N):
            with tf.name_scope("kalman_body"):
                # Predict.
                m = A @ m
                P = A @ tf.linalg.matmul((P - Pinf), A, transpose_b=True) + Pinf
                # Innovation, its variance, and the Gaussian log-density term.
                PHt = tf.linalg.matmul(P, H, transpose_b=True)
                S = var_likelihood + H @ PHt
                v = y_observed[k] - H @ m
                lZ = -0.5 * (tf.square(v) / S + tf.math.log(2 * pi * S))
                log_marg_lik += lZ[0, 0]
                # Update.
                K = PHt / S
                m = m + K @ v
                P = P - K @ tf.linalg.matmul(S, K, transpose_b=True)
    return log_marg_lik


@tf.function(input_signature=input_signature)
def gradient_step(y_observed, var_likelihood, var_prior, len_prior):
    """Evaluate ``kalman_filter`` and its gradients w.r.t. the three hyperparameters.

    Returns:
        A 4-tuple ``(value, d_var_likelihood, d_var_prior, d_len_prior)``.
        ``value`` is exactly what ``kalman_filter`` returns: the accumulated
        log marginal likelihood, not its negation.
    """
    watched = [var_likelihood, var_prior, len_prior]
    # watch_accessed_variables=False: only the explicitly watched inputs are tracked.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(watched)
        log_marg_lik = kalman_filter(y_observed, var_likelihood, var_prior, len_prior)
    d_var_likelihood, d_var_prior, d_len_prior = tape.gradient(log_marg_lik, watched)
    return log_marg_lik, d_var_likelihood, d_var_prior, d_len_prior

# Warm-up: call each tf.function once on a short prefix so that tracing happens
# outside the timed region. The (None,)-shaped input signature lets the traced
# graphs be reused for the full-length input below.
warmup_y = y[:10]
nlml = kalman_filter(warmup_y, var_y, var_f, len_f)
nlml, d1, d2, d3 = gradient_step(warmup_y, var_y, var_f, len_f)

start = time()
with tf.name_scope("forward"):
    nlml = kalman_filter(y, var_y, var_f, len_f)
elapsed = time() - start
print(f"Forward computation time: {elapsed:2.2f}")

Matlab

% Benchmark: run the filter once over num_steps standard-normal observations.
num_steps = 50000;

rng(123)  % fixed seed so every run filters the same data
y = randn(num_steps, 1);
% y = linspace(0, 1, num_steps)';

% Hyperparameters, passed as (var_likelihood, var_prior, len_prior).
var_y = 0.1;
len_f = 1;
var_f = 1;

tic
nlml = kalman_filter(y, var_y, var_f, len_f)
toc

function log_marg_lik = kalman_filter(y_observed, var_likelihood, var_prior, len_prior)
% KALMAN_FILTER Log marginal likelihood of y_observed computed with a Kalman filter.
%   y_observed      observations, read as y_observed(k) for k = 1..size(y_observed, 1)
%   var_likelihood  observation-noise variance added to H*P*H'
%   var_prior       scale of the stationary state covariance Pinf
%   len_prior       lengthscale
    lam = 5.0^0.5 / len_prior;
    F = [0.0, 1.0, 0.0;
         0.0, 0.0, 1.0;
         -lam^3, -3 * lam^2, -3 * lam];
    H = [1.0, 0.0, 0.0];
    Ht = H';
    kappa = 5.0 / 3.0 * var_prior / len_prior ^ 2;
    Pinf = [var_prior, 0.0, -kappa;
            0.0, kappa, 0.0;
            -kappa, 0.0, 25.0 * var_prior / len_prior^4.0];
    A = expm(F);                   % one-step transition matrix
    x = zeros(size(F, 1), 1);      % state mean
    P = Pinf;                      % state covariance
    log_marg_lik = 0.0;
    for k = 1:size(y_observed, 1)
        % Predict.
        x = A * x;
        P = A * (P - Pinf) * A' + Pinf;
        % Innovation, its variance, and the Gaussian log-density term.
        S = var_likelihood + H * P * Ht;
        v = y_observed(k) - H * x;
        log_marg_lik = log_marg_lik + (-v ^ 2 / S / 2 - log(2 * pi * S) / 2);
        % Update.
        K = P * Ht / S;
        x = x + K * v;
        P = P - K * S * K';
    end
end
@dan-zheng
Copy link
Member

What compilation flags did you pass to swift? Did you try swift -Ounchecked, which disables preconditions?

@awav
Copy link
Contributor Author

awav commented Jan 15, 2020

@dan-zheng, how can I pass this option to the Xcode?

I'm not using the terminal because of this issue (but that's a different story):

→ swift build
dyld: Library not loaded: @rpath/libswiftCore.dylib
  Referenced from: /Library/Developer/Toolchains/swift-tensorflow-DEVELOPMENT-2019-12-23-a.xctoolchain/usr/bin/swift-build
  Reason: image not found
Abort trap: 6

@dan-zheng
Copy link
Member

@dan-zheng, how can I pass this option to the Xcode?

I forget, but you can Google "Xcode pass Swift compiler flags".


The dyld: Library not loaded: @rpath/libswiftCore.dylib issue and workaround are known: tensorflow/swift#347

Workaround: manually add an RPATH to SwiftPM binaries.

# Add the toolchain's runtime library directory to the RPATH of each SwiftPM tool.
VERSION=swift-tensorflow-DEVELOPMENT-2019-12-23-a
TOOLCHAIN_BIN=/Library/Developer/Toolchains/$VERSION.xctoolchain/usr/bin
for tool in swift-build swift-package swift-run swift-test; do
    sudo install_name_tool -add_rpath "@executable_path/../lib/swift/macosx" "$TOOLCHAIN_BIN/$tool"
done

@awav
Copy link
Contributor Author

awav commented Jan 15, 2020

I forget, but you can Google "Xcode pass Swift compiler flags".

I did that :). I found posts for older Xcode versions, but the new one has a different interface.

Workaround: manually add an RPATH to SwiftPM binaries.

Thanks for the tip. Now it works, and the benchmark time didn't change much with -Ounchecked:

→ swift run -c release -Xswiftc -Ounchecked Kalman
Start execution
Finished execution
Elapsed time: 9.111591231019702

@pschuh
Copy link
Contributor

pschuh commented Jan 16, 2020

Hi, Just a note here.
I think you'll be a lot happier if you don't use the TensorFlow part, but just the AD part. We still have derivatives defined for many of those operators (over floats). TensorFlow always constructs objects (among other overheads) for every operation, which adds a fixed per-operation cost. This means that for very simple things (like multiplying 2x2 matrices), you incur a very large overhead. Basically, the rule of thumb is to only use TensorFlow for things that are big enough to perform well on an accelerator.

@awav
Copy link
Contributor Author

awav commented Jan 16, 2020

@pschuh, thanks a lot! It does make sense. Does it mean that S4TF behaves like TensorFlow 2.0 in Python without applying tf.function? In graph mode, TensorFlow builds the computation graph only once and reuses it for different inputs.

I think you'll be a lot happier if you don't use the Tensorflow part, but just the AD part. We still have derivatives defined for many of those operators (over floats)

Are there matmul and transpose operations for arrays in Swift? I use 3x3 matrices :)

@awav
Copy link
Contributor Author

awav commented Jan 16, 2020

@dan-zheng, @pschuh Not related to this thread, but the situation with Swift as a differentiable programming language reminds me of the essay "Too Much Calculus" by Gilbert Strang. I do want to do research using Swift because I like its Python interoperability, the language style, and its direction of development (mainly set by Google), but the necessary and missing piece is efficient linear algebra. I thought that TensorFlow integrated into Swift smoothly without any additional cost. If that's fixable, great; if not, then Swift needs support for matrix (tensor) operations from BLAS and LAPACK (e.g. Apple's Accelerate). This is vital for all types of research; it is not enough to say that Swift supports automatic differentiation out of the box (although that is a very cool thing).
For example, I found this implementation https://github.com/AlexanderTar/LASwift (based on Apple Accelerate), according to the benchmark it has reasonable speed, but it has no gradients, no batching, and interface is very restrictive.

@dan-zheng
Copy link
Member

@dan-zheng, @pschuh Not related to this thread, but the situation with Swift as differentiable programming language reminds me of the essay "Too much calculus" by Gilbert Strang. I do want to do research using Swift because I like swift-python compatibility, language styling and its direction of development (mainly set by Google), but the necessary and the lacking bit is efficient linear algebra operations. I thought that TensorFlow integrates into Swift smoothly without any additional cost. If that's fixable then great, if not, then swift needs support for matrix (tensor) operations from BLAS, LAPACK (Apple's Accelerate). This is vital for all types of research, it is not enough to say that swift supports auto-differentiation out-of-the-box (although, this is a very cool thing).

Thanks for your feedback! Please create a new issue next time. 🙂

I guess everyone has different needs and priorities: you may be super interested in differentiable linear algebra operations, but our team is working on things important to us, like:

  • Swift models and benchmarking
  • A faster, accelerable Tensor implementation
  • Meta-programming for machine learning use cases
  • Automatic differentiation support for more language features.

Our team is trying to tackle issues that are important for our use cases (deep learning models). Regarding open-source work, we try to focus on fundamental issues that only we have the resources to solve. Other issues, like adding and improving APIs, are perfect for the open-source community to contribute to.

I don't think we currently have the motivation or bandwidth for a team member to work on tensorflow/swift-apis full-time (e.g. to add linear algebra operations, backed by BLAS/LAPACK), but that may change in the future.


I think a productive way for community members like you to contribute is "adding the APIs that you want to use". You've been doing a great job of kickstarting and driving the "linear algebra operations" effort, in TF-980 and #562: if it weren't for you, we probably wouldn't have any linear algebra APIs right now, since it's not a priority for our team. This is a great example of how open-source contributors like you can influence an open-source project.

Perhaps tensorflow/swift-apis in its current state isn't ideal for you due to more fundamental issues, like poor performance or the lack of multiple backend support (e.g. supporting BLAS/LAPACK as a Tensor backend instead of TensorFlow). I feel these underlying issues are, and should be, our team's focuses.

For example, I found this implementation https://github.com/AlexanderTar/LASwift (based on Apple Accelerate), according to the benchmark it has reasonable speed, but it has no gradients, no batching, and interface is very restrictive.

Since differentiable programming is designed as a general Swift language feature, it's totally possible to extend an existing library like https://github.com/AlexanderTar/LASwift and make linear algebra operations differentiable (using @differentiable and @derivative).

However, for the most impact, and to unify the community's efforts, I would recommend finding ways to contribute to tensorflow/swift-apis. Related mailing list discussion.

apple/swift-numerics is another Swift ecosystem project that plans to add support for ShapedArray and linear algebra operations. I think there's definitely room for collaboration with that project.

@awav
Copy link
Contributor Author

awav commented Jan 17, 2020

Please create a new issue next time.

Ok, I will make a separate post in the S4TF mailing list.

Perhaps tensorflow/swift-apis in its current state isn't ideal for you due to more fundamental issues, like poor performance or the lack of multiple backend support (e.g. supporting BLAS/LAPACK as a Tensor backend instead of TensorFlow)

I don't think we currently have the motivation or bandwidth for a team member to work on tensorflow/swift-apis full-time (e.g. to add linear algebra operations, backed by BLAS/LAPACK)

TensorFlow provides many LAPACK and BLAS routines, and it also has a GPU backend with cuBLAS support. This unification of computational backends makes TensorFlow a nice choice. The best solution for me would be to generate Swift bindings for TensorFlow so that S4TF could call the C++ kernel implementations directly without any extra cost.

You've been doing a great job of kickstarting and driving the "linear algebra operations" effort

Thanks! :)

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants