[Bug][Dataloader] unable to mmap 2048 bytes from file <filename not specified>: Cannot allocate memory (12) #92134

Open
zejun-chen opened this issue Jan 13, 2023 · 22 comments
Labels
module: dataloader (Related to torch.utils.data.DataLoader and Sampler), module: multiprocessing (Related to torch.multiprocessing), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@zejun-chen
Contributor

zejun-chen commented Jan 13, 2023

🐛 Describe the bug

Hi,

Here are the results I observed when running my workload with PyTorch 1.13 on Ubuntu 20.04, training ResNet-50 (RN50) on ImageNet.
At epoch 25, the following error is thrown:

Epoch: [25][3096/5005]	Time  0.551 ( 0.548)	Data  0.000 ( 0.001)	Loss 1.9055e+00 (2.1268e+00)	Acc@1  57.81 ( 51.79)	Acc@5  79.69 ( 76.05)
Epoch: [25][3097/5005]	Time  0.547 ( 0.548)	Data  0.000 ( 0.001)	Loss 1.9593e+00 (2.1268e+00)	Acc@1  55.08 ( 51.80)	Acc@5  81.25 ( 76.06)
Exception in thread Thread-52:
Traceback (most recent call last):
  File "/home/gta/miniconda3/envs/zejun/lib/python3.9/threading.py", line 973, in _bootstrap_inner
    self.run()
  File "/home/gta/miniconda3/envs/zejun/lib/python3.9/threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "/home/gta/miniconda3/envs/zejun/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 52, in _pin_memory_loop
    do_one_step()
  File "/home/gta/miniconda3/envs/zejun/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 29, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/gta/miniconda3/envs/zejun/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/gta/miniconda3/envs/zejun/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 294, in rebuild_storage_fd
    storage = cls._new_shared_fd(fd, size)
RuntimeError: falseINTERNAL ASSERT FAILED at "/home/gta/zejun/pytorch/aten/src/ATen/MapAllocator.cpp":323, please report a bug to PyTorch. unable to mmap 2048 bytes from file <filename not specified>: Cannot allocate memory (12)

This RuntimeError is thrown from the torch DataLoader. I found the community has already hit this issue: https://discuss.pytorch.org/t/pytorch-cannot-allocate-memory/134754.

When I set the DataLoader workers to 0, the error goes away, but training time increases a lot because a single process has to fetch and decode the dataset. I also checked the host memory usage and no OOM happened. The per-process fd limit is not exceeded either. We therefore wonder whether this issue can be fixed; otherwise it may block model training with PyTorch.
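
For reference, a minimal sketch of the relevant DataLoader settings (the path, transforms and worker count below are placeholders, not my exact script):

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Hypothetical reproduction of the relevant settings from the ImageNet example.
train_dataset = datasets.ImageFolder(
    "/path/to/imagenet/train",
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ]),
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=8,    # the error only appears with multiprocessing workers
    pin_memory=True,  # the failure surfaces in the pin_memory thread
)
# Setting num_workers=0 avoids the error but makes data loading much slower.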

Versions

PyTorch 1.13
Python 3.9
Torchvision 1.14
Running BS: 256
Model: ResNet50
Dataset: ImageNet

Ubuntu 20.04
Host Mem 128G
CPU: Intel Xeon Gold 6342

cc @ssnl @VitalyFedyunin @ejguan @NivekT

@janeyx99 added the module: dataloader and triaged labels Jan 13, 2023
@ejguan added the module: multiprocessing label Jan 13, 2023
@ejguan
Contributor

ejguan commented Jan 13, 2023

Based on the log, it looks like you just ran out of shared memory. It would be good if you could share a minimal script.

@zejun-chen
Contributor Author

Based on the log, it looks like you just ran out of shared memory. It would be good if you could share a minimal script.

Thank you for your attention. We encountered this error when running the PyTorch example for RN50 at the link below:
https://github.com/pytorch/examples/blob/main/imagenet/main.py

We enabled this model script on XPU and hit this mmap issue (Cannot allocate memory).
When I set the DataLoader workers to 0, the error is gone.

We are using a 64-bit Linux machine, so in theory mmap can map files into a very large virtual memory space and the memory-mapping segments can be huge. Is it possible that munmap releases memory lazily, so the virtual memory mapped by previous iterations is not released immediately and the mapping segments do not have enough room for the next iteration's DataLoader mappings? Additionally, this issue takes a long time to reproduce, roughly 50 epochs of training.

@ejguan
Contributor

ejguan commented Jan 19, 2023

If you can't allocate memory, can you please try the file_system sharing strategy? https://pytorch.org/docs/stable/multiprocessing.html#file-system-file-system
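
For example (this is the documented torch.multiprocessing API; call it before creating the DataLoader):

import torch.multiprocessing as mp

# The default strategy on Linux is "file_descriptor"; "file_system" backs
# shared storages with files in /dev/shm instead of passing file descriptors.
print(mp.get_sharing_strategy())       # e.g. "file_descriptor"
mp.set_sharing_strategy("file_system")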

@zejun-chen
Contributor Author

zejun-chen commented Jan 30, 2023

@ejguan Thank you for sharing the link.
We checked our failure and monitored the fd usage; it is not caused by exceeding the OS limit on open file descriptors.
From the screenshot below, it looks like the main process is mapping shared memory coming from the worker processes.
[screenshot omitted]

We found that others in the community have also encountered this issue:
https://discuss.pytorch.org/t/pytorch-cannot-allocate-memory/134754.

Thank you.

@ejguan
Contributor

ejguan commented Jan 30, 2023

You might want to profile your training script to see when you run out of memory. You might be able to rely on psutil to check your shared memory at each step.
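
For example, something along these lines (just a sketch; where you call it in the training loop is up to you):

import psutil

def log_memory(step):
    vm = psutil.virtual_memory()
    shm = psutil.disk_usage("/dev/shm")
    print(f"step {step}: shared={vm.shared / 1e9:.2f} GB, "
          f"available={vm.available / 1e9:.2f} GB, "
          f"/dev/shm used={shm.used / 1e9:.2f} GB")

# call log_memory(i) once per training iteration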

@zejun-chen
Contributor Author

zejun-chen commented Jan 31, 2023

@ejguan Thank you for your help. I will check the memory footprint with psutil.
I see there is an open issue related to the above: #13246
I wonder if the mmap failure is related to that DataLoader OOM issue.

@zejun-chen
Contributor Author

zejun-chen commented Feb 1, 2023

Hi, @ejguan

I checked the memory usage with psutil. Here is the log:

epoch =  1 . iter =  809
svmem(total=540768874496, available=523621277696, percent=3.2, used=11168075776, free=287612063744, active=145943617536, inactive=90581217280, buffers=1317834752, cached=240670900224, shared=2472620032, slab=11804663808)
sswap(total=0, used=0, free=0, percent=0.0, sin=0, sout=0)

epoch =  1 . iter =  810
svmem(total=540768874496, available=523618697216, percent=3.2, used=11170656256, free=287609483264, active=145943617536, inactive=90583797760, buffers=1317834752, cached=240670900224, shared=2472620032, slab=11804909568)
sswap(total=0, used=0, free=0, percent=0.0, sin=0, sout=0)

epoch =  1 . iter =  811
svmem(total=540768874496, available=523610955776, percent=3.2, used=11178397696, free=287601741824, active=145943617536, inactive=90591539200, buffers=1317834752, cached=240670900224, shared=2472620032, slab=11804909568)
sswap(total=0, used=0, free=0, percent=0.0, sin=0, sout=0)

epoch =  1 . iter =  812
svmem(total=540768874496, available=523675918336, percent=3.2, used=11113103360, free=287665991680, active=145958989824, inactive=90516205568, buffers=1317834752, cached=240671944704, shared=2472976384, slab=11804872704)
sswap(total=0, used=0, free=0, percent=0.0, sin=0, sout=0)

For the shared memory usage, no sharp increase is observed.

Filesystem                  Size  Used Avail Use% Mounted on
/dev/sda2                   917G  440G  431G  51% /
tmpfs                       48G  2.3G   46G   5% /dev/shm

@zejun-chen
Contributor Author

zejun-chen commented Feb 6, 2023

Hi, @ejguan

Here is a linked issue whose error message is similar to the one I reported in this issue:
#65198

I checked the FDs the workers consume while processing my dataset (ImageNet) and did not find the FD limit being exceeded.
May I know how the fds are recycled in the torch DataLoader?
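
For reference, a sketch of one way to watch the fd counts of the main process and its DataLoader workers with psutil (not my exact monitoring script):

import os
import resource
import psutil

def report_fds():
    # Per-process open-fd limit (the same limit as `ulimit -n`).
    soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    main = psutil.Process(os.getpid())
    # DataLoader worker processes are children of the main process.
    for p in [main] + main.children(recursive=True):
        print(f"pid={p.pid} open_fds={p.num_fds()} limit={soft}")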

@zejun-chen
Contributor Author

zejun-chen commented Mar 2, 2023

Hi, @ejguan

Sorry to bother you. I have some confusion about how the PyTorch DataLoader works with multiple workers.
[screenshot omitted]

Here the workers put their result data into the data queue. The worker process calls reduce_tensor and reduce_storage, and then the main process calls rebuild_storage_fd. I wonder why reduce_tensor and reduce_storage need to be called?

I traced the syscalls in the main process and saw the mmap below. I wonder what this mmap is for.
03:00:23.414558 mmap(NULL, 36, PROT_READ|PROT_WRITE, MAP_SHARED, 8, 0) = 0x7ff26bbb4000 <0.000030>

Thank you.

@ejguan
Contributor

ejguan commented Mar 2, 2023

I wonder why reduce_tensor and reduce_storage need to be called?

We need to move the underlying data to shared memory so that the main process is able to get the data from the worker process.
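
As a small standalone illustration of that mechanism (a sketch, not the DataLoader internals themselves): putting a tensor on a torch.multiprocessing queue moves its storage into shared memory on the sending side, which is what reduce_storage arranges, and the receiving side rebuilds a view of that same storage, which is what rebuild_storage_fd does.

import torch
import torch.multiprocessing as mp

def worker(q, done):
    t = torch.randn(4)
    q.put(t)     # ForkingPickler reduces the storage and moves it into shared memory
    done.wait()  # keep the worker alive until the main process has rebuilt the storage

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    q, done = ctx.Queue(), ctx.Event()
    p = ctx.Process(target=worker, args=(q, done))
    p.start()
    received = q.get()           # main process rebuilds the storage from the shared fd/file
    print(received.is_shared())  # True: the storage lives in shared memory
    done.set()
    p.join()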

@zejun-chen
Contributor Author

zejun-chen commented Mar 3, 2023

I wonder why reduce_tensor and reduce_storage need to be called?

We need to move the underlying data to shared memory so that the main process is able to get the data from the worker process.

Thank you for your reply. I see the worker immediately deletes the data once it has finished processing it and put it into the queue. The code is here:

del data, idx, index, r # save memory

Does that mean the data queue copies this data entirely from the worker's virtual memory space into /dev/shm, so the worker can safely del data?
Then the data queue waits for the main process to mmap this memory from /dev/shm into its own virtual memory space?

@ejguan
Copy link
Contributor

ejguan commented Mar 3, 2023

Does that mean the data queue copies this data entirely from the worker's virtual memory space into /dev/shm, so the worker can safely del data?
Then the data queue waits for the main process to mmap this memory from /dev/shm into its own virtual memory space?

Right.

@zejun-chen
Contributor Author

Does that mean the data queue copies this data entirely from the worker's virtual memory space into /dev/shm, so the worker can safely del data?
Then the data queue waits for the main process to mmap this memory from /dev/shm into its own virtual memory space?

Right.

Thank you for your explanation. :)

@ASMIftekhar

Hi @zejun-chen, did you find a solution to the problem?

@zejun-chen
Contributor Author

Hi @zejun-chen, did you find a solution to the problem?

Hi, @ASMIftekhar
We have not found the root cause, but the issue disappeared after we fixed a host memory leak in our project.
If you also hit this problem, you can raise the OS mmap limit as a workaround; setting num_workers=0 in the DataLoader also helps.
Thank you.

@ASMIftekhar

Hello,
I am just writing down the solution that worked in my case; I didn't do a deep analysis. From my dataloader I was returning images (almost 12K images) and annotations (a list of size ~12K, where each element is a Python dictionary) for every mini-batch. With the 'file_system' sharing strategy, the workers dump the fetched data into /dev/shm (mentioned above by @zejun-chen) and the main process reads it via mmap. The error comes from this reading. Instead of returning a plain list for the annotations, I wrapped it in the following way:
return images, pickle.dumps(annotations). Then, while reading, I just used pickle.loads. The error is gone this way.
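
If it helps, here is a rough sketch of the pattern (the dataset, shapes and fields are made up for illustration, not my actual data; Linux/fork start method assumed):

import pickle
import torch
from torch.utils.data import Dataset, DataLoader

class AnnotatedDataset(Dataset):
    # Hypothetical dataset returning an image tensor plus a large annotation list.
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        image = torch.randn(3, 224, 224)
        annotations = [{"id": idx, "bbox": [0, 0, 10, 10]}]  # list of dicts
        # Serialize the Python objects so only one bytes blob per sample
        # crosses the worker -> main process boundary.
        return image, pickle.dumps(annotations)

def collate(batch):
    images = torch.stack([img for img, _ in batch])
    packed = [ann for _, ann in batch]  # keep the pickled bytes as-is
    return images, packed

loader = DataLoader(AnnotatedDataset(), batch_size=4, num_workers=4, collate_fn=collate)

for images, packed in loader:
    annotations = [pickle.loads(p) for p in packed]  # restore in the main process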

I got this idea from: this blog and this issue

@zejun-chen
Contributor Author

Hello, @ASMIftekhar
Thank you for sharing. It looks like reading via pickle.dumps/pickle.loads does not trigger copy-on-write (COW) on the annotations (a Python object), so the host memory and /dev/shm are not exhausted and the mmap issue (cannot allocate memory) goes away.
Thank you.

@ZhiyuanChen
Contributor

same here

@prefer-potato

prefer-potato commented May 26, 2023

I met the same problem, and it was solved when I initialize data_list lazily in __getitem__() instead of __init__() in my Dataset class.

import torch

class my_dataset(torch.utils.data.Dataset):
    def __init__(self):
        # Keep this empty here; the data is created lazily on first access,
        # i.e. inside each DataLoader worker rather than in the parent process.
        self.data_list = []
        # other code in __init__()

    def init_data_list(self):
        self.data_list = torch.randn([100000])

    def __len__(self):
        return 100000

    def __getitem__(self, index):
        if len(self.data_list) == 0:  # not yet initialized in this process
            self.init_data_list()
        return self.data_list[index]

@GoenitzYs

GoenitzYs commented Aug 22, 2023

Hello, I met the problem, too. Finally, I found the problem might be in the customized collate function passed to the DataLoader (collate_fn=...), which might cause the problem when the return type is a list or a tuple.

As I use the DGL library and my data structure is dgl.DGLGraph, I solved the problem by merging the group of graphs with dgl.batch() in the collate function; the problem is gone when the collate function returns a tuple of dgl.DGLGraph objects rather than a group of graph lists or tuples.

Anyway, anyone confronted with the same problem could try this solution. I believe the problem may have more than one cause, but this approach might be the fix for one of them :)
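
A rough sketch of the idea (assuming DGL is installed and each dataset sample is a (dgl.DGLGraph, label) pair; not my exact code):

import dgl
import torch

def collate(samples):
    # Each sample is assumed to be a (dgl.DGLGraph, label) pair.
    graphs, labels = map(list, zip(*samples))
    # Merge the per-sample graphs into a single batched DGLGraph instead of
    # sending a Python list of graphs through the DataLoader queue.
    return dgl.batch(graphs), torch.tensor(labels)

# loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4,
#                                      collate_fn=collate)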

@YjZhang-sudo

Maybe you could try increasing vm.max_map_count.
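
For example, you can check the current limit like this (a sketch; raising it requires root, e.g. sudo sysctl -w vm.max_map_count=<larger value>):

# Linux only: the kernel's per-process limit on the number of memory-mapped regions.
with open("/proc/sys/vm/max_map_count") as f:
    print("vm.max_map_count =", f.read().strip())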

@ZhiyuanChen
Contributor

Hello, I met the problem, too. Finally, I found the problem might be in the customized collate function passed to the DataLoader (collate_fn=...), which might cause the problem when the return type is a list or a tuple.

As I use the DGL library and my data structure is dgl.DGLGraph, I solved the problem by merging the group of graphs with dgl.batch() in the collate function; the problem is gone when the collate function returns a tuple of dgl.DGLGraph objects rather than a group of graph lists or tuples.

Anyway, anyone confronted with the same problem could try this solution. I believe the problem may have more than one cause, but this approach might be the fix for one of them :)

I believe this is the key, as we encountered this problem just like you.
