
When memory pressure is high, THCStorage.cu resize() algo use device2device copy , will cause out of memory crash. #72

Closed
smartbitcoin opened this issue Mar 10, 2015 · 5 comments

Comments

@smartbitcoin

The code logic is currently:

float *data;
// Allocate the new buffer while the old buffer is still held on the device:
THCudaCheck(cudaMalloc((void**)(&data), size * sizeof(float)));
// Copy the old contents device-to-device, then free the old buffer:
THCudaCheck(cudaMemcpyAsync(data, self->data, THMin(self->size, size) * sizeof(float), cudaMemcpyDeviceToDevice));
THCudaCheck(cudaFree(self->data));
Consider this scenario: the GPU has 4 GB of RAM, of which 3 GB is used. If Lua calls resize() to grow the storage from 3 GB to 3.5 GB, the code above crashes with out-of-memory, because the old and new buffers must coexist during the device-to-device copy, even though the GPU still has 1 GB of spare RAM relative to the net growth.

The optimized logic should be: if there is not enough device RAM for the malloc, first copy the current data from device to host, then release the device RAM, then malloc the new device buffer, and finally copy the contents back from host to device.
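A minimal sketch of that fallback path, staging through host memory so the old and new device buffers never coexist (the helper name `resizeViaHost` is hypothetical, not cutorch API; error handling is abbreviated):

```cuda
#include <cuda_runtime.h>
#include <stdlib.h>

// Hypothetical fallback for THCStorage resize when a direct
// device-to-device resize would exceed free device RAM.
static cudaError_t resizeViaHost(float **d_data, long oldSize, long newSize)
{
    long keep = (oldSize < newSize) ? oldSize : newSize;

    // 1. Copy the current contents from device to host.
    float *h_buf = (float *)malloc(keep * sizeof(float));
    cudaMemcpy(h_buf, *d_data, keep * sizeof(float), cudaMemcpyDeviceToHost);

    // 2. Release the old device buffer *before* allocating the new one,
    //    so peak device usage is max(oldSize, newSize), not their sum.
    cudaFree(*d_data);

    // 3. Allocate the new device buffer.
    cudaError_t err = cudaMalloc((void **)d_data, newSize * sizeof(float));
    if (err == cudaSuccess) {
        // 4. Copy the saved contents back from host to device.
        cudaMemcpy(*d_data, h_buf, keep * sizeof(float), cudaMemcpyHostToDevice);
    }
    free(h_buf);
    return err;
}
```

The trade-off is two PCIe transfers instead of one on-device copy, but the out-of-memory failure mode for large resizes goes away.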

Please consider this request's importance, because device RAM is always very tight. It would be better to release before the malloc.

@smartbitcoin

[attached screenshot: cudaalloc]

@soumith soumith changed the title THCStorage.cu resize() algo use device2device copy , will cause out of memory crash. When memory pressure is high, THCStorage.cu resize() algo use device2device copy , will cause out of memory crash. Mar 10, 2015
@soumith commented Mar 10, 2015

When memory pressure is high, you should do that explicitly on your own. I don't think it is fair to expect cutorch to do a particular operation on the CPU implicitly in the background, as this can have many performance side effects that people generally would not expect.

@soumith soumith closed this as completed Mar 10, 2015
@smartbitcoin

soumith, device memory needs better management, especially since CUDA itself is still not that smart here. Consider a scenario: you have 4 GB of device RAM and allocate 1.5 GB first; later you want to resize to 2.5 GB. The resize() call can still crash with "out of memory" if the first 1.5 GB allocation is not aligned to a memory boundary: the free space then sits fragmented in the middle of device RAM, and CUDA cannot allocate a contiguous 2.5 GB block, even though enough RAM is available in total.

resize() is only called a few times during the whole training process, but it is the main cause of crashes.
Swap out, free, then swap in would be a good algorithm for the case where there is enough RAM but resize() still fails (allocating small chunks instead of one huge block would be even better, but is hard to implement). It has only a tiny performance impact, but it is a life-saving change.

I did my testing; now the issue is not the performance impact, it's that alloc and free are controlled by the CUDA runtime. So even if you free the "old" content before resize(), that memory space is still not returned to the CUDA runtime immediately. I'll try to figure out how to do a successful "swap", lol.
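Whether freed memory actually shows up as available again can be observed with cudaMemGetInfo (a diagnostic sketch using only standard CUDA runtime calls; the actual numbers depend on the driver and allocator, so no particular output is guaranteed):

```cuda
#include <cuda_runtime.h>
#include <stdio.h>

int main(void)
{
    size_t freeB, totalB;
    float *p;

    // Allocate 512 MB, then check how much the reported free memory drops.
    cudaMalloc((void **)&p, (size_t)512 << 20);
    cudaMemGetInfo(&freeB, &totalB);
    printf("after malloc: %zu MB free\n", freeB >> 20);

    // Free it, synchronize to flush pending work, and check again
    // whether the runtime reports the memory as available.
    cudaFree(p);
    cudaDeviceSynchronize();
    cudaMemGetInfo(&freeB, &totalB);
    printf("after free:   %zu MB free\n", freeB >> 20);
    return 0;
}
```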

@deepakjnath

@smartbitcoin were you able to find a solution for this problem? I am encountering the same issue and find it to be a major bottleneck.

@smartbitcoin

Kind of. I switched to Caffe, whose Blob structure lets you control GPU RAM flexibly, but you need to write some C++ code.
