STS models running slower on GPUs #1395

Open
WillianFuks opened this issue Aug 6, 2021 · 2 comments
This simple example running on Colab shows this issue:

import numpy as np
import tensorflow_probability as tfp
import tensorflow as tf
from time import time


tfd = tfp.distributions
tfb = tfp.bijectors

# Synthetic data: column 0 is the observed series, the remaining columns are regressors.
ds = np.random.rand(1300, 5)

components = []
ots = ds[:1000, 0].astype(np.float32)

obs_sd = ots.std()

# Structural time series model: local level + linear regression on the covariates.
level_component = tfp.sts.LocalLevel(observed_time_series=ots)
components.append(level_component)

linear_component = tfp.sts.LinearRegression(design_matrix=ds[:, 1:].astype(np.float32))
components.append(linear_component)

model = tfp.sts.Sum(components, observed_time_series=ots)

# Time the variational fit end to end.
t0 = time()
optimizer = tf.optimizers.Adam(learning_rate=0.1)
variational_steps = 200
variational_posteriors = tfp.sts.build_factored_surrogate_posterior(model=model)

@tf.function
def _run_vi():
    # Fit the surrogate posterior with VI, then draw samples from it.
    tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=model.joint_log_prob(
            observed_time_series=ots
        ),
        surrogate_posterior=variational_posteriors,
        optimizer=optimizer,
        num_steps=variational_steps
    )

    samples = variational_posteriors.sample(100)
    return samples, None

_run_vi()
t1 = time()

print(t1 - t0)

Running on CPU takes a few seconds, whereas running on GPU takes minutes. I also tested the same procedure with TFP's example notebook on Colab, and the GPU run was again slower (roughly 2-3x).

I tried bigger datasets to see whether data volume was the issue, but after increasing it tenfold the GPU run no longer finished on Colab.

I also tested previous versions of TensorFlow and TensorFlow Probability, but the issue remained. Has something changed that makes GPUs slower?
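In case it helps reproduce the comparison within a single GPU runtime, a rough sketch is to pin the fit to each device in turn (the device strings depend on the runtime, and the surrogate and tf.function are rebuilt inside each scope so placement takes effect at trace time; this reuses the names from the snippet above):

from time import time

for device in ['/CPU:0', '/GPU:0']:
    with tf.device(device):
        # Rebuild the surrogate and optimizer under the device scope so their
        # variables and traced ops land on that device.
        surrogate = tfp.sts.build_factored_surrogate_posterior(model=model)
        opt = tf.optimizers.Adam(learning_rate=0.1)

        @tf.function
        def fit():
            return tfp.vi.fit_surrogate_posterior(
                target_log_prob_fn=model.joint_log_prob(observed_time_series=ots),
                surrogate_posterior=surrogate,
                optimizer=opt,
                num_steps=variational_steps)

        t0 = time()
        fit()
        print(device, time() - t0)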

Thanks in advance!

@davmre
Contributor

davmre commented Aug 9, 2021

Thanks for filing this! My impression is that STS models on GPU have essentially always been slow; I don't think it's anything recent. We haven't pushed hard on this since it hasn't been a blocker for our internal use cases, but I think there's almost certainly room for improvement.

Some general context: STS models tend to involve lots of small ops, which is a different performance profile from the average NN model, where the computation is dominated by a few big matmuls. Since there's some overhead every time an op is interpreted, using XLA compilation (the jit_compile=True arg to tf.function) can be particularly helpful for STS models, though I doubt that this by itself would be a panacea for GPU performance.
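Concretely, compiling the fit from the snippet above with XLA would look roughly like this (a sketch reusing the names defined there; jit_compile is the current spelling, older TF releases used experimental_compile):

@tf.function(jit_compile=True)
def _run_vi_xla():
    # Same VI fit as above, but the whole training loop is XLA-compiled.
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=model.joint_log_prob(observed_time_series=ots),
        surrogate_posterior=variational_posteriors,
        optimizer=optimizer,
        num_steps=variational_steps)

losses = _run_vi_xla()
samples = variational_posteriors.sample(100)  # sampling can stay outside the compiled function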

Without having done any profiling, one initial hypothesis might be that there's some op that (for some reason) always executes on the CPU, so using the GPU for everything else just adds a data-transfer bottleneck at every step, negating any gain from GPU acceleration. The tools at https://www.tensorflow.org/guide/gpu_performance_analysis would probably be useful for investigating this.
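As a quick first check, logging device placement and capturing a trace for the TensorBoard profiler would look something like this (a sketch; the logdir path is arbitrary):

tf.debugging.set_log_device_placement(True)  # print the device each op runs on

with tf.profiler.experimental.Profile('/tmp/tfp_sts_profile'):
    _run_vi()

# Then inspect the trace in TensorBoard's Profile tab to look for
# host-to-device copies or ops pinned to the CPU.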

I do think this is an important issue that we'll want to understand better, though I don't personally expect to be able to put a lot of time into this in the next few weeks.

@WillianFuks
Author

WillianFuks commented Aug 16, 2021

Thanks @davmre for the reply!

I did have the impression it ran faster before, but maybe I mixed up the variational inference and HMC setups; I'm not sure now.

Hoping improved GPU support eventually happens :)

Best,

Will
