Description
When training MrVI for batch correction, I see this note: "Jax module moved to cuda:0. Note: PyTorch Lightning will show the GPU is not being used for the Trainer." It is followed by:
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kassab/miniconda3/envs/CB803/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing Trainer(accelerator='gpu').
So I ran watch -n1 nvidia-smi to check whether the GPU is actually in use. It shows activity (~65 W power draw and ~2 GB of memory allocated). However, the training time (~40-60 minutes) on my dataset (AnnData object with n_obs × n_vars = 32826 × 12303) makes me suspect something is wrong: scVI takes 10-15 minutes on the same dataset.
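Since MrVI runs on JAX rather than PyTorch, Lightning's "used: False" message does not necessarily reflect what JAX itself is doing. A quick sketch for checking which backend JAX will dispatch to, independent of Lightning's reporting:

```python
import jax

# CudaDevice entries here mean JAX found the GPU and will use it;
# CpuDevice entries mean it silently fell back to the CPU.
print(jax.devices())

# The default backend platform, e.g. "gpu" or "cpu"
print(jax.default_backend())
```

If this prints CPU devices despite a CUDA-capable machine, the slow training would be explained by JAX running on CPU rather than by any Trainer setting.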
I have tried:
from lightning.pytorch import Trainer
trainer = Trainer(accelerator="gpu", devices=1, strategy="ddp_notebook")
which shows:
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
I also tried updating JAX to version 0.5.1, but compatibility issues arose, so I reverted to jax==0.4.35.
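For reference, a CUDA-enabled build of that pinned JAX version can be installed via the cuda12 extra (an assumption on my part that this extra matches your CUDA 12 driver; the exact extras vary by JAX version):

```shell
# Pulls jax 0.4.35 together with a matching CUDA 12 jaxlib build
pip install "jax[cuda12]==0.4.35"
```

A CPU-only `pip install jax` would also "work" but silently leave JAX without GPU support, which matches the symptoms above.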
My system info:
python == 3.12
CUDA Version: 12.2 (driver), CUDA compilation tools release 12.8, V12.8.93 (nvcc)
scvi.__version__ == '1.3.0'
torch.__version__ == '2.6.0+cu124'
pl.__version__ == '2.5.0.post0'
import time

import torch
from scvi.external import MRVI

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

model = MRVI(adata)

start_time = time.time()
# Train the model
model.train(max_epochs=500, accelerator="gpu", devices=1)
torch.cuda.synchronize()
end_time = time.time()

max_memory_reserved = torch.cuda.max_memory_reserved() / (1024 ** 3)
max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 3)

print(f"Training Time: {end_time - start_time:.2f} seconds")
print(f"Max Reserved GPU Memory: {max_memory_reserved:.2f} GB")
print(f"Max Allocated GPU Memory: {max_memory_allocated:.2f} GB")
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kassab/miniconda3/envs/CB803/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/home/kassab/miniconda3/envs/CB803/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()

Versions:
scvi.version == '1.3.0'
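One caveat when interpreting the memory numbers in the script above: torch.cuda.max_memory_allocated() only tracks PyTorch's own caching allocator, and since MrVI's parameters and training step live in JAX, its GPU memory will not appear there. JAX also preallocates a large fraction of GPU memory up front by default, which can make nvidia-smi look busy regardless of actual utilization. A sketch of disabling that preallocation for cleaner measurements, using JAX's documented XLA client environment flags (these must be set before the first `import jax`):

```python
import os

# Tell JAX/XLA not to grab ~75% of GPU memory at startup;
# must be set before JAX is imported anywhere in the process.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Optional: allocate on demand via the platform allocator instead
# of one big block (slower, but nvidia-smi then reflects real use).
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
```

With these set, the memory column in nvidia-smi tracks what the model actually uses, which makes it easier to tell GPU-backed JAX training apart from a CPU fallback.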