
[Bug] Possible memory leak in botorch.optim.optimize_acqf #641

Closed
mshvartsman opened this issue Dec 23, 2020 · 21 comments
Labels: bug (Something isn't working), upstream issue

@mshvartsman
Contributor

🐛 Bug

As far as I can tell, botorch.optim.optimize_acqf leaves a tiny bit of memory behind somewhere. It seems worse for q-batched acquisition functions (at least for qUCB and qEI) than for analytic ones, and worse on ubuntu than on OSX. Calls to fit_gpytorch_model and to the acquisition function itself seem fine.

To reproduce

Sorry this is a bit long.

import torch
import numpy as np
import gpytorch
from botorch.models.gpytorch import GPyTorchModel
from botorch.fit import fit_gpytorch_model
from botorch.optim import optimize_acqf
from botorch.acquisition import (
    qUpperConfidenceBound,
    ExpectedImprovement,
    qExpectedImprovement,
)

from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, VariationalStrategy

from tqdm import trange

# Haven't checked if this happens with non-variational GPs yet
class GPClassificationModel(ApproximateGP, GPyTorchModel):

    _num_outputs = 1

    def __init__(
        self, inducing_min, inducing_max, inducing_size=10,
    ):

        inducing_points = torch.linspace(
            inducing_min[0], inducing_max[0], inducing_size
        )

        variational_distribution = MeanFieldVariationalDistribution(
            inducing_points.size(0)
        )
        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=False,
        )
        super(GPClassificationModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=1),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred

    def set_train_data(self, x, y):
        self.train_inputs = (x,)
        self.train_targets = y


bounds = torch.Tensor(np.r_[-1, 1])[:, None]
ntrials = 1000
restarts = 10
samps = 1000
q = 1
n = 10

# initialize
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
model = GPClassificationModel(inducing_min=bounds[0], inducing_max=bounds[1])

acq = qUpperConfidenceBound(model=model, beta=3.98)

mll = gpytorch.mlls.VariationalELBO(likelihood, model, n)
x = torch.rand(size=(n,))
y = torch.randint_like(x, 0, 2, dtype=torch.long)
model.set_train_data(x, y)
model.train()

# just call something in a tight loop to see if memory grows
for i in trange(ntrials):
    # this call keeps memory steady
    # fit_gpytorch_model(mll)

    # this call keeps memory steady
    # _ = acq(x[:, None])

    # this call grows memory by a little bit every call
    new_x, batch_acq_values = optimize_acqf(
        acq_function=acq, bounds=bounds, q=q, num_restarts=restarts, raw_samples=samps,
    )

Running the above with mprof, here's what no leak looks like:
[mprof plot: no leak]

Here's what a leak on OSX looks like:
[mprof plot: leak on OSX]

Here's what a leak on ubuntu looks like:
[mprof plot: leak on ubuntu]

Expected Behavior

Expecting no memory leak here -- I'm trying to run some benchmarks, which means that I run many synthetic opt runs and anything long-running gets killed.

System information

  • botorch version: 0.3.3
  • gpytorch version: 1.3.0
  • pytorch version: 1.7.1
  • OS: OSX (mild apparent leak), ubuntu (worse apparent leak).
mshvartsman added the bug (Something isn't working) label Dec 23, 2020
@saitcakmak
Contributor

I have observed something similar in the past, but forgot to investigate further. Is it possible that the memory leak happens in calls to model.posterior() (which UCB would be doing)? That was my guess at the time.

I had a base model, for which I was generating fantasy models f_model = model.fantasize(...) and calling f_model.posterior(). The fantasy model was discarded after this call. This was repeated many times, keeping the base model fixed. I would eventually run into an OOM error. When I looked into it, I found that this part of the code was leaving some memory behind.

I have a feeling that the gpytorch context managers, which botorch uses by default, may be causing this. Let me see if I can verify this.

@saitcakmak
Contributor

I can't seem to reproduce the memory leak with your code, but replacing the last part with the following produces a memory leak:

from botorch.models import SingleTaskGP
from botorch.sampling import SobolQMCNormalSampler

m2 = SingleTaskGP(
    torch.rand(100, 2),
    torch.randn(100, 1),
)

for i in trange(ntrials):
    fant = m2.fantasize(torch.rand(5, 1, 2), SobolQMCNormalSampler(5))
    z = torch.rand(1000, 5, 5, 1, 2)
    post = fant.posterior(z)

This leads to a memory leak as seen here:
[mprof plot: memory growing over iterations]

If I instead disable GPyTorch's fast predictive variances:

for i in trange(ntrials):
    with gpytorch.settings.fast_pred_var(False):
        fant = m2.fantasize(torch.rand(5, 1, 2), SobolQMCNormalSampler(5))
        z = torch.rand(1000, 5, 5, 1, 2)
        post = fant.posterior(z)

The leak seems to disappear (at least partially):
[mprof plot: leak mostly gone]

By default BoTorch uses gpytorch.settings.fast_pred_var(True), which caches some computations for future use. You could disable it and see if it resolves the memory leak.
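For reference, here is a minimal sketch (untested against this exact setup) of how that workaround could be applied to the repro loop above, reusing the acq, bounds, q, restarts, and samps variables from the original snippet:

# Sketch: disable GPyTorch's fast predictive variances around the optimization call.
for i in trange(ntrials):
    with gpytorch.settings.fast_pred_var(False):
        new_x, batch_acq_values = optimize_acqf(
            acq_function=acq, bounds=bounds, q=q, num_restarts=restarts, raw_samples=samps,
        )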

@Balandat
Contributor

I can repro this issue, looking into this further.

I'm a little hamstrung by running into this error frequently; it seems to be a known issue with mprof :( pythonprofilers/memory_profiler#163

Traceback (most recent call last):
  File "/Users/balandat/miniconda3/envs/botorch/bin/mprof", line 8, in <module>
    sys.exit(main())
  File "/Users/balandat/miniconda3/envs/botorch/lib/python3.8/site-packages/mprof.py", line 877, in main
    actions[get_action()]()
  File "/Users/balandat/miniconda3/envs/botorch/lib/python3.8/site-packages/mprof.py", line 279, in run_action
    mp.memory_usage(proc=p, interval=args.interval, timeout=args.timeout, timestamps=True,
  File "/Users/balandat/miniconda3/envs/botorch/lib/python3.8/site-packages/memory_profiler.py", line 367, in memory_usage
    stream.write("MEM {0:.6f} {1:.4f}\n".format(*mem_usage))
TypeError: format() argument after * must be an iterable, not NoneType

The suggestion to run with sudo doesn't fix this either.

@mshvartsman
Contributor Author

python -m mprof run blah seems to fix it 99% of the time for me.

@Balandat
Contributor

The fact that Sait's example shows significant growth means that there must be some tensors / graph components of significant size that are not being garbage collected...

@mshvartsman
Contributor Author

They seem to be different issues though -- disabling fast predictive variances doesn't fix my issue. Turning off all fast computations (with gpytorch.settings.fast_computations(False, False, False)) does, so now it's just about figuring out which of them it is :).
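For anyone else trying that workaround, here is a sketch of the same loop with the three fast_computations flags spelled out by keyword (the keywords correspond to the three positional False values above; names from the original repro are assumed):

# Sketch: turn off all of GPyTorch's fast computation paths around the optimization call.
for i in trange(ntrials):
    with gpytorch.settings.fast_computations(
        covar_root_decomposition=False, log_prob=False, solves=False
    ):
        new_x, batch_acq_values = optimize_acqf(
            acq_function=acq, bounds=bounds, q=q, num_restarts=restarts, raw_samples=samps,
        )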

@Balandat
Contributor

Hmm, gpytorch.settings.fast_computations(False, False, False) doesn't fix the issue for me...

@Balandat
Contributor

Balandat commented Dec 23, 2020

OK, just tested this on an old Linux machine; no memory leak there:

[mprof plot: ubuntu, no leak]

Here is the exact same code run on MacOS X 11.1 with the same config (python 3.8.5, torch 1.7.1, botorch 0.3.3, gpytorch 1.3.0):

[mprof plot: macOS, leak, slow run]

Note that the Mac run is super slow (10x slower) compared to ubuntu, and this is on a modern 8-core Core i9 compared to a super old dual-core Core i3 on the Linux machine. Also, the memory allocation is much smaller (it's unclear from these profiles whether the ubuntu one is moving around; it might just not be visible at this scale). Something is definitely fishy here.

@Balandat
Contributor

Hmm, it turns out that running this with python -m mprof run instead of mprof (to circumvent the issues mentioned earlier) is what caused this to be so slow. I was able to get a run in with just mprof run that looks a lot more reasonable in terms of runtime. However, memory usage is still growing over time.

[mprof plot: macOS, leak]

@mshvartsman
Contributor Author

Ok, there seem to be two separate issues:

  1. The Ubuntu situation seems to be hard to repro: I can't get it to come up again with the code I sent on the same machine after reinstalling gpytorch / botorch / pytorch. I can still get it to come up consistently in my code when I instantiate botorch objects as part of a bunch of other stuff, but let's call that one user error for now.
  2. The OSX issue seems pretty consistent, and I can't make it go away by disabling fast computations or by running single-threaded with MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 (see the snippet below). I'll take a deeper look next week unless someone else gets to it sooner.
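For reference, a minimal sketch of the single-threaded setup mentioned in item 2 (this mirrors setting MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 on the command line; the environment variables need to be set before torch is imported):

# Sketch: force single-threaded execution for the repro.
import os
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import torch
torch.set_num_threads(1)  # also cap torch's intra-op thread pool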

@Balandat
Contributor

@mshvartsman Any updates on this? On my end it's still reproducible on OSX but works fine on ubuntu.

@mshvartsman
Contributor Author

Short version: no news.

Long version: the Linux issue still comes up when I run on very many threads (~20+) and may be some weird edge case or a bad interaction between threads and Python multiprocessing. I'd say it's unlikely to come up in normal usage. No news on the Mac side. Sorry for letting this drop.

@Balandat
Contributor

Fair enough. This one seems nasty enough that I'm not very tempted to go real deep here. I'll leave the issue open for tracking though.

@esantorella
Member

This is a PyTorch issue

Here’s what I see on OSX when running the original example (yep, it’s still happening):

[mprof plot: memory growth on OSX]

Here is some pure PyTorch code that produces the same pattern (the first ~2s are setup time):

import torch
import time


def do_tensor_stuff(samps: int) -> None:
    """Similar to the content of qUpperConfidenceBound.forward."""
    samples = torch.rand((512, samps, 1, 1))
    obj = samples[:, :, :, 0]
    mean = obj.mean(dim=0)
    ucb_samples = mean + 3.98 * (obj - mean).abs()
    ucb_samples.max(dim=-1)[0].mean(dim=0)


if __name__ == "__main__":
    ntrials = 2000
    samps = 1000
    start = time.monotonic()

    for i in range(ntrials):
        do_tensor_stuff(samps)
    print(time.monotonic() - start)

[mprof plot: memory growth from the pure PyTorch loop]

I don’t think this would cause an OOM; tensors are getting appropriately deallocated

mprof measures the amount of memory reserved by the process and not available to the rest of the system; it doesn’t measure the amount of memory occupied by live Python objects or by C objects such as tensors. I checked what Python objects remain at the end of the loop pretty thoroughly using gc.get_objects and gc.DEBUG_LEAK, and I closely inspected acq to make sure we’re not repeatedly writing tensors to it. At the end of the loop, there’s nothing allocated that shouldn’t be.
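For anyone who wants to repeat that kind of check, here is a rough sketch of the gc-based inspection described above (illustrative only; run it around whatever loop you suspect):

import gc

gc.set_debug(gc.DEBUG_LEAK)  # DEBUG_LEAK keeps everything the collector finds in gc.garbage

# ... run the suspicious loop here ...

gc.collect()
print(len(gc.garbage))  # inspect what, if anything, was left behind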

Also, I tried to get it to OOM by increasing the size of the tensors and running many iterations, and wasn’t able to cause an OOM. And I see that total time is very close to linear in the number of iterations. But if this is causing practical problems for anyone, of course it would be great to know that.

I think this has to do with too much memory being reserved for torch tensors, especially when a lot of small ones are created and destroyed. For related issues see here.

@Balandat
Contributor

Balandat commented Jan 4, 2023

Thanks for the thorough investigation! Given the result, should we close out the issue as it doesn't seem to be a botorch or gpytorch issue and also doesn't seem to cause any concrete problems?

@esantorella
Member

Yeah, I would consider this closed unless there is an example where this is causing a practical issue like an OOM or slowdown when memory gets high.

I have a couple follow-ups in mind that would reduce memory usage, but those should be in separate issues/PRs:

  • Some small optimizations
  • Using in-place operations instead of creating and destroying small tensors. This completely eliminates the growth of memory over time in the do_tensor_stuff example in the snippet above and reduces memory usage by ~33%, but at a cost to readability (a rough sketch follows below). If I can find a realistic example where that will help, I'll put a PR up for discussion.
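For illustration, here is a rough sketch of what such an in-place rewrite of the do_tensor_stuff example above could look like (do_tensor_stuff_inplace and the preallocated buffer are hypothetical names, not anything in BoTorch; this is only safe off the autograd path, per the caveat below):

def do_tensor_stuff_inplace(samps: int, buf: torch.Tensor) -> None:
    """In-place variant of do_tensor_stuff: reuse a preallocated buffer instead of
    allocating fresh tensors on every call."""
    torch.rand((512, samps, 1, 1), out=buf)  # refill the reusable buffer with new samples
    obj = buf[:, :, :, 0]
    mean = obj.mean(dim=0)
    # ucb_samples = mean + 3.98 * (obj - mean).abs(), computed in place on the buffer view
    obj.sub_(mean).abs_().mul_(3.98).add_(mean)
    obj.max(dim=-1)[0].mean(dim=0)


buf = torch.empty((512, samps, 1, 1))
for i in range(ntrials):
    do_tensor_stuff_inplace(samps, buf)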

@Balandat
Contributor

Balandat commented Jan 4, 2023

Using in-place operations instead of creating and destroying small tensors.

We'll have to be careful about doing that when on a path that should be differentiable - autograd doesn't play well with in-place tensor operations.

@AdrianSosic
Contributor

AdrianSosic commented Oct 19, 2023

Hi! I just noticed the same behavior in my code and found this conversation (and also followed up on some of the other memory-related issues linked here). The problem was reported to me by someone using my code and they actually told me that their process was killed after some time due to OOM errors.

I will try to reproduce their settings and see if PyTorch's memory allocation was really the cause, but in the meantime I wanted to ask whether you have also noticed something similar lately. I mean, this thread was closed because it seemed like it's not a leak but rather behavior controlled by PyTorch's memory scheduler – but how can we be certain?

My question: is there a way to explicitly trigger a release of the allocated (but unused) memory? On the GPU, this seems to be possible via torch.cuda.empty_cache, but is there a way for the CPU? If there were, one could easily use it to rule out a leak.

@saitcakmak
Contributor

Hi @AdrianSosic. We're not aware of any memory leaks in BoTorch at this time. Note that memory usage will increase super-linearly in the number of train & test points when using GPs, so it is not hard to OOM with BoTorch.

My question: is there a way to explicitly trigger a release of the allocated (but unused) memory? On the GPU, this seems to be possible via torch.cuda.empty_cache, but is there a way for the CPU? If there were, one could easily use it to rule out a leak.

Python's gc.collect() can sometimes help, but it works in mysterious ways. I am not aware of a PyTorch-specific one.

@AdrianSosic
Contributor

Hi @saitcakmak, thanks for your answer!

Just to avoid potential misunderstandings here: I'm well aware of the fact that there is a super-linear memory increase with GPs (i.e. the typical O(N^3) consumption in the case of exact inference). But this is not my point. Instead, I was referring to @esantorella's post, which shows that memory allocation still increases even when all objects of a finished operation have been properly deallocated.

As far as I understood, the current hypothesis is that this is caused by pytorch's memory management and it would not cause any OOM errors, i.e. the memory would be released before one runs into OOM. @esantorella tried to verify this empirically:

Also, I tried to get it to OOM by increasing the size of the tensors and running many iterations, and wasn’t able to cause an OOM. And I see that total time is very close to linear in the number of iterations. But if this is causing practical problems for anyone, of course it would be great to know that.

However, since not finding a positive example is no guarantee that this will never happen, this is just speculative. What I tried to explain in my previous post is that some colleagues of mine now seem to have run into exactly such a situation. Since the problem occurred neither on my system nor with my own code, I haven't been able to reproduce it so far, but I will try to do so with their help. So in case any of you has new insights into the mechanics of the scheduler that have not been mentioned so far, I'd be happy to hear them =)

@esantorella
Member

To distinguish a true memory leak from the Python process reserving more memory than it needs, I would look at the contents of gc.get_objects() to see what tensors are being tracked by Python's garbage collector. If there's a true memory leak, tensors will keep being tracked by the garbage collector even after they are not needed.

If you wanted the number and sizes of tensors that are currently tracked, you could do the following:

import gc
from collections import Counter

import torch

print(Counter(t.shape for t in gc.get_objects() if isinstance(t, torch.Tensor)))

If you wrap all of your code in a function and there are still tensors in gc.get_objects() after the function exits, that's a good indication of a memory leak. Similarly, if the number of tensors grows as you run a loop (rather than just the sizes of the tensors growing), that may also indicate a leak. On the other hand, if there is nothing large in gc.get_objects() but you see the Python process using a lot of memory, that may indicate that the issue is more about the Python process requesting more memory than it needs.
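A rough sketch of that kind of check (count_tracked_tensors and run_workload are made-up names for illustration; substitute the loop you actually want to test):

import gc

import torch


def count_tracked_tensors() -> int:
    return sum(isinstance(o, torch.Tensor) for o in gc.get_objects())


def run_workload() -> None:
    # stand-in for the loop you actually want to check
    for _ in range(100):
        x = torch.rand(1000, 1000)
        (x @ x).sum()


run_workload()
gc.collect()
# If tensors created inside run_workload are still tracked here, that points to a true leak.
print("tensors still tracked:", count_tracked_tensors())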
