OOM with RotatE Model when Early Stopping #1251

Open
lrey-civetta opened this issue Mar 28, 2023 · 13 comments
Labels
bug Something isn't working

Comments

@lrey-civetta

lrey-civetta commented Mar 28, 2023

Describe the bug

Hi there,

Related to #129 and #433. I'm hitting memory errors when running a RotatE model on the PrimeKG dataset. It happens consistently when embedding_dim is set higher than ~512 and/or num_negs_per_pos is set higher than ~100, and only when early stopping is enabled. Setting a lower batch_size decreases the chance that it happens; however, I would like to use automatic batch size optimization.

The error could simply be that my GPU (NVIDIA A10) isn't large enough for this model size. However, I want to confirm that this is the cause, rather than a recurrence of the OOM-during-early-stopping bug (as the error only occurs during early stopping). Is there a way to differentiate between a 'your-model-is-too-big-for-your-GPU' error and this early-stopping-based OOM bug?
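
One check I've been considering (just a sketch; the small fixed batch size is arbitrary, and it reuses the model and the train/val splits from the reproduction below): train briefly without the stopper, then run the evaluator by hand with a fixed, small batch size. If that evaluation fits in memory, the model itself presumably isn't too big for the GPU, and the OOM would be specific to the early-stopping path.

from pykeen.evaluation import RankBasedEvaluator

# `model` is a (briefly) trained RotatE model; `train`/`val` are the
# TriplesFactory splits from the reproduction below
evaluator = RankBasedEvaluator(filtered=True)
results = evaluator.evaluate(
    model=model,
    mapped_triples=val.mapped_triples,
    batch_size=64,  # deliberately small and fixed
    additional_filter_triples=[train.mapped_triples],
)
print(results.get_metric("hits_at_10"))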

Error:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 17.88 GiB (GPU 0; 22.20 GiB total capacity; 1.86 GiB already allocated; 1.11 GiB free; 20.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

How to reproduce

The train, val, and test splits are custom partitions of the entire PrimeKG dataset, each of type TriplesFactory.

pipeline_result = pipeline(
    training=train,
    validation=val,
    testing=test,

    model='RotatE',
    model_kwargs={
        "embedding_dim": 580,
        "random_seed": 42
    },

    training_loop='sLCWA',
    training_kwargs={
        "num_epochs": 1000, # ceiling
        "batch_size": None,

        "checkpoint_name": "pykeen-checkpoint.pt",
        "checkpoint_frequency": 0,
        "checkpoint_directory": "/opt/ml/checkpoints"
    },

    optimizer='Adagrad',
    optimizer_kwargs={
        'lr': 0.035
    },

    loss='NSSA',

    negative_sampler='basic',
    negative_sampler_kwargs={
        "num_negs_per_pos": 108
    },

    result_tracker='MLFlow',
    result_tracker_kwargs='postgresql://......',

    random_seed=42,

    stopper='early',
    stopper_kwargs={
        "frequency": 10,
        "patience": 2,
        "relative_delta": 0.005,
        "metric": "hits_at_k"
    },

    evaluator='RankBasedEvaluator',
    evaluator_kwargs={
        "filtered": True,
    },

    use_testing_data=True,
    use_tqdm=True
)

Environment

Run inside an AWS ml.g5.xlarge instance with 24 GB of GPU memory (NVIDIA A10 GPU).

Using PyKEEN v1.10.1. Output of python -m pykeen version:

Key Value
OS posix
Platform Linux
Release 5.15.0-67-generic
Time Tue Mar 28 14:25:36 2023
Python 3.10.6
PyKEEN 1.10.1
PyKEEN Hash UNHASHED
PyKEEN Branch
PyTorch 1.13.1+cu117
CUDA Available? false
CUDA Version 11.7
cuDNN Version 8500

Additional information

No response

Issue Template Checks

  • This is not a feature request (use a different issue template if it is)
  • This is not a question (use the discussions forum instead)
  • I've read the text explaining why including environment information is important and understand if I omit this information that my issue will be dismissed
@lrey-civetta added the bug label on Mar 28, 2023
@mberr
Member

mberr commented Mar 28, 2023

Hi @lrey-civetta ,

Would you mind sharing the full traceback of the error? That makes it easier to understand where it is coming from.

Also, can you confirm that your process is the only one accessing this GPU during training (i.e., there is no other program running in parallel and accessing the same GPU)?

@lrey-civetta
Author

lrey-civetta commented Mar 28, 2023

Sure, here you go:

File "/usr/local/lib/python3.10/dist-packages/pykeen/pipeline/api.py", line 1546, in pipeline
stopper_instance, configuration, losses, train_seconds = _handle_training(
File "/usr/local/lib/python3.10/dist-packages/pykeen/pipeline/api.py", line 1190, in _handle_training
losses = training_loop_instance.train(
File "/usr/local/lib/python3.10/dist-packages/pykeen/training/training_loop.py", line 378, in train
result = self._train(
File "/usr/local/lib/python3.10/dist-packages/pykeen/training/training_loop.py", line 735, in _train
callback.post_epoch(epoch=epoch, epoch_loss=epoch_loss)
File "/usr/local/lib/python3.10/dist-packages/pykeen/training/callbacks.py", line 443, in post_epoch
callback.post_epoch(epoch=epoch, epoch_loss=epoch_loss, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pykeen/training/callbacks.py", line 367, in post_epoch
if self.stopper.should_stop(epoch):
File "/usr/local/lib/python3.10/dist-packages/pykeen/stoppers/early_stopping.py", line 230, in should_stop
metric_results = self.evaluator.evaluate(
File "/usr/local/lib/python3.10/dist-packages/pykeen/evaluation/evaluator.py", line 213, in evaluate
rv = evaluate(
File "/usr/local/lib/python3.10/dist-packages/pykeen/evaluation/evaluator.py", line 687, in evaluate
relation_filter = _evaluate_batch(
File "/usr/local/lib/python3.10/dist-packages/pykeen/evaluation/evaluator.py", line 760, in _evaluate_batch
scores = model.predict(hrt_batch=batch, target=target, slice_size=slice_size, mode=mode)
File "/usr/local/lib/python3.10/dist-packages/pykeen/models/base.py", line 481, in predict
return self.predict_h(hrt_batch, **kwargs, heads=ids)
File "/usr/local/lib/python3.10/dist-packages/pykeen/models/base.py", line 374, in predict_h
scores = self.score_h(rt_batch, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pykeen/models/nbase.py", line 527, in score_h
scores=self.interaction.score(h=h, r=r, t=t, slice_size=slice_size, slice_dim=1),
File "/usr/local/lib/python3.10/dist-packages/pykeen/nn/modules.py", line 265, in score
return self(h=h, r=r, t=t)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pykeen/nn/modules.py", line 412, in forward
return self.__class__.func(**self._prepare_for_functional(h=h, r=r, t=t))
File "/usr/local/lib/python3.10/dist-packages/pykeen/nn/functional.py", line 576, in rotate_interaction
return negative_norm(h - t, p=2, power_norm=False)

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 17.88 GiB (GPU 0; 22.20 GiB total capacity; 1.86 GiB already allocated; 1.11 GiB free; 20.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

@lrey-civetta
Author

And no, nothing else is running. This is actually being spun up on a fresh AWS instance each time I run it.

@mberr
Member

mberr commented Mar 28, 2023

From your config, it looks like you are also using automatic batch size selection/maximization for training:

    training_kwargs={
        "batch_size": None, ...
    },

While in theory this shouldn't cause OOMs in evaluation, I wonder whether the problem persists if you manually set a lower training batch size than the maximum possible one.
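
For example (only a sketch; 1024 is an arbitrary value well below the auto-selected maximum, and the rest is copied from your config):

    training_kwargs={
        "num_epochs": 1000,
        "batch_size": 1024,  # fixed, instead of letting the batch size search maximize it

        "checkpoint_name": "pykeen-checkpoint.pt",
        "checkpoint_frequency": 0,
        "checkpoint_directory": "/opt/ml/checkpoints"
    },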

@lrey-civetta
Author

I can confirm that both "batch_size": None and "batch_size": 4096 throw the error. The automatic batch size selection generally settled on 16384.

@mberr
Member

mberr commented Mar 28, 2023

Hm, okay, I don't immediately see where this is coming from.

If you want to dig deeper into it, you could try to get more information about the allocated GPU memory / the allocated tensors right before the crash. One easy way would be to define a custom Stopper class:

from typing import Optional

import torch

from pykeen.stoppers import EarlyStopper
from pykeen.typing import InductiveMode


class MemoryLoggingEarlyStopper(EarlyStopper):
    def should_stop(self, epoch: int, *, mode: Optional[InductiveMode] = None) -> bool:
        print(torch.cuda.memory_summary())
        # more info in this thread + the linked post:
        # https://discuss.pytorch.org/t/list-all-the-tensors-and-their-memory-allocation/144108
        return super().should_stop(epoch=epoch, mode=mode)

and provide this as stopper:

pipeline_result = pipeline(
  stopper=MemoryLoggingEarlyStopper,
  ..., # rest remains unchanged
)
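
If you want to go one level deeper, the approach from the linked forum thread essentially boils down to iterating over live objects via gc. A minimal, PyKEEN-independent sketch you could call inside should_stop:

import gc

import torch


def log_cuda_tensors() -> None:
    """Best-effort listing of all live CUDA tensors with their sizes."""
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                size_mib = obj.element_size() * obj.nelement() / 2**20
                print(type(obj).__name__, tuple(obj.shape), obj.dtype, f"{size_mib:.1f} MiB")
        except Exception:
            # some tracked objects raise on attribute access; just skip them
            pass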

@lrey-civetta
Author

Thanks! I'll give this a shot.

@LizzAlice
Contributor

I have this problem too, but with ConvKB. I'll look into this some more in the coming week; just wanted to let you know.

@LizzAlice
Contributor

I didn't really have time to look deeper into this, but I noticed that the early stopper also has a batch_size argument. Is it possible that automatic memory optimization is not used in the early stopper?
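
If I'm reading it right, that would also suggest a workaround: pin the stopper's evaluation batch size explicitly via stopper_kwargs. Just a sketch, and I'm assuming the argument is called evaluation_batch_size, so the exact name may be off:

    stopper_kwargs={
        "frequency": 10,
        "patience": 2,
        "metric": "hits_at_k",
        "evaluation_batch_size": 128,  # fixed instead of auto-maximized
    },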

@LizzAlice
Contributor

LizzAlice commented Jun 29, 2023

So I tried the code that you gave. I am on PyKEEN version 1.13.1. The code I was executing was:

pipeline_result = pipeline(
    dataset="WN18RR",
    dataset_kwargs={"create_inverse_triples": True},
    model="ComplEx",
    loss="MarginRankingLoss",
    optimizer="Adadelta",
    negative_sampler="BasicNegativeSampler",
    random_seed=42,
    device="gpu",
    epochs=1000,
    stopper=stopper,
    stopper_kwargs={
        "patience": 2,
        "frequency": 50,
        "relative_delta": 0.001,
        "metric": "adjusted_arithmetic_mean_rank_index",
    },
    evaluation_fallback=True,
)

When I use the early stopper, I get an OOM error; if I don't, there is no error.
Here is the output using the code you suggested:

File "/home/test/Individual.py", line 81, in execute
    pipeline_result = pipeline(
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/pipeline/api.py", line 1546, in pipeline
    stopper_instance, configuration, losses, train_seconds = _handle_training(
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/pipeline/api.py", line 1190, in _handle_training
    losses = training_loop_instance.train(
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/training/training_loop.py", line 378, in train
    result = self._train(
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/training/training_loop.py", line 735, in _train
    callback.post_epoch(epoch=epoch, epoch_loss=epoch_loss)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/training/callbacks.py", line 443, in post_epoch
    callback.post_epoch(epoch=epoch, epoch_loss=epoch_loss, **kwargs)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/training/callbacks.py", line 367, in post_epoch
    if self.stopper.should_stop(epoch):
  File "/home/test/Individual.py", line 19, in should_stop   
    return super().should_stop(epoch=epoch, mode=mode)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/stoppers/early_stopping.py", line 230, in should_stop
    metric_results = self.evaluator.evaluate(
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/evaluation/evaluator.py", line 213, in evaluate
    rv = evaluate(
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/evaluation/evaluator.py", line 687, in evaluate
    relation_filter = _evaluate_batch(
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/evaluation/evaluator.py", line 760, in _evaluate_batch
    scores = model.predict(hrt_batch=batch, target=target, slice_size=slice_size, mode=mode)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/models/base.py", line 481, in predict
    return self.predict_h(hrt_batch, **kwargs, heads=ids)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/models/base.py", line 374, in predict_h
    scores = self.score_h(rt_batch, **kwargs)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/models/nbase.py", line 527, in score_h
    scores=self.interaction.score(h=h, r=r, t=t, slice_size=slice_size, slice_dim=1),  
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/nn/modules.py", line 265, in score
    return self(h=h, r=r, t=t)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/nn/modules.py", line 412, in forward
    return self.__class__.func(**self._prepare_for_functional(h=h, r=r, t=t))
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/pykeen/nn/modules.py", line 544, in func
    return torch.real(einsum("...d, ...d, ...d -> ...", h, r, torch.conj(t)))
  File "/home/anaconda3/envs/pykeen/lib/python3.10/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 30.95 GiB (GPU 0; 44.56 GiB total capacity; 500.56 MiB already allocated; 12.08 GiB free; 31.37 GiB reserved in total by PyTorch)

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  508003 KB |   32345 MB |  457912 GB |  457911 GB |
|       from large pool |  507789 KB |   32343 MB |  444670 GB |  444669 GB |
|       from small pool |     214 KB |       4 MB |   13242 GB |   13242 GB |
|---------------------------------------------------------------------------|
| Active memory         |  508003 KB |   32345 MB |  457912 GB |  457911 GB |
|       from large pool |  507789 KB |   32343 MB |  444670 GB |  444669 GB |
|       from small pool |     214 KB |       4 MB |   13242 GB |   13242 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   32124 MB |   32368 MB |  381608 MB |  349484 MB |
|       from large pool |   32122 MB |   32364 MB |  376588 MB |  344466 MB |
|       from small pool |       2 MB |       6 MB |    5020 MB |    5018 MB |
|---------------------------------------------------------------------------|
| Non-releasable memory |   31627 MB |   31763 MB |   37093 GB |   37062 GB |
|       from large pool |   31626 MB |   31761 MB |   19975 GB |   19944 GB |
|       from small pool |       1 MB |       3 MB |   17117 GB |   17117 GB |
|---------------------------------------------------------------------------|
| Allocations           |      28    |      55    |  119109 K  |  119109 K  |
|       from large pool |       8    |      15    |    7335 K  |    7335 K  |
|       from small pool |      20    |      47    |  111773 K  |  111773 K  |
|---------------------------------------------------------------------------|
| Active allocs         |      28    |      55    |  119109 K  |  119109 K  |
|       from large pool |       8    |      15    |    7335 K  |    7335 K  |
|       from small pool |      20    |      47    |  111773 K  |  111773 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |       9    |      16    |    6312    |    6303    |
|       from large pool |       8    |      13    |    3802    |    3794    |
|       from small pool |       1    |       3    |    2510    |    2509    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       9    |      15    |   54016 K  |   54016 K  |
|       from large pool |       1    |       6    |     140 K  |     140 K  |
|       from small pool |       8    |      13    |   53875 K  |   53875 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

I am not quite sure how to interpret that, however.

@mberr
Member

mberr commented Sep 15, 2023

Hi @lrey-civetta @LizzAlice ,

you may want to try again with the fix brought to master by 2b4b776
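
Until that is part of a release, you should be able to pull it in directly from the repository, e.g. with pip install --upgrade git+https://github.com/pykeen/pykeen.git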

@ddofer

ddofer commented Apr 21, 2024

I am still getting the error discussed here, on Apple silicon (MPS). No warnings are shown; it just crashes (OOM) in evaluation. It happens with TransE, RotE, ...

It does NOT happen when running on CPU.

@ddofer

ddofer commented Jun 4, 2024

The crash happens in CPU-only mode as well as on MPS, and with automatic batch size search disabled. The crash happens in the eval stage.

from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory

# df_kg is a pandas DataFrame with columns SUBJ_NAME, PREDICATE, OBJ_NAME
tf = TriplesFactory.from_labeled_triples(
    df_kg.rename(
        columns={"SUBJ_NAME": "source", "PREDICATE": "type", "OBJ_NAME": "target"}
    )[["source", "type", "target"]].values,
)
training, testing, validation = tf.split([0.8, 0.1, 0.1])

results = pipeline(
    training=training,
    testing=testing,
    validation=validation,
    model="TransE",
    epochs=2,
    dimensions=128,
    random_seed=42,
    # device="mps",  # Apple silicon - crashes OOM; but works with cpu? depends on size, embed size
    # device="cpu",  # or "cuda" (runs stably on cuda)
    evaluation_fallback=True,
    stopper="early",
    training_loop_kwargs=dict(automatic_memory_optimization=False),
    evaluator_kwargs=dict(automatic_memory_optimization=False, batch_size=32),
)
