Multi-GPU support #42

Closed
soumith opened this Issue Sep 23, 2014 · 41 comments

@soumith
Member

soumith commented Sep 23, 2014

Multi-GPU support has been implemented in cutorch (and, by extension, in all Torch CUDA libraries such as cunn, cudnn, etc.).

  • Switch the device on the fly with cutorch.setDevice(devID)
  • All CUDA calls are asynchronous and can be synchronized with cutorch.synchronize()

Example usage for tensors:

-- Let us do matrix addition for matrices sitting on two different GPUs
cutorch.setDevice(1)
matrix1 = torch.CudaTensor(10):fill(1)
print(matrix1) -- printing is a synchronous call, so you don't have to explicitly call cutorch.synchronize()
cutorch.setDevice(2)
matrix2 = torch.CudaTensor(10):fill(2)
print(matrix2) 
matrix2:add(matrix1) -- matrix1 is seamlessly copied onto GPU2 and added to matrix2
print(matrix2)

If you want to do data-parallel training of neural nets (including convnets), your training loop can run like this (a rough Lua sketch follows the steps below):

For each mini-batch:

1. load data (preferably using multiple threads, for example using [threads-ffi](https://github.com/torch/threads-ffi))
2. loop over GPUs (the loop below is completely asynchronous, so the GPUs run in parallel)
  2.1. model[gpuX]:forward
  2.2. criterion[gpuX]:forward
  2.3. criterion[gpuX]:backward
  2.4. model[gpuX]:backward
3. cutorch.synchronize()
4. accumulate GPUx's gradParameters to GPU1's gradParameters
5. do SGD on GPU1
6. copy back GPU1's parameters to GPUx
7. cutorch.synchronize() and print accuracy etc.

Loop back to 1 for next mini-batch
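A rough sketch of that loop in Lua (illustrative only; models, criterions, params, gradParams, gradBuffer, loadBatch, nBatches and learningRate are placeholders you would set up yourself, e.g. one model clone and one flattened parameter/gradient pair per GPU):

local nGPU = cutorch.getDeviceCount()

for batch = 1, nBatches do
   local inputs, targets = loadBatch()              -- 1. load data (ideally in worker threads)

   for g = 1, nGPU do                               -- 2. queue forward/backward on each GPU
      cutorch.setDevice(g)
      local output = models[g]:forward(inputs[g])
      criterions[g]:forward(output, targets[g])
      local gradOutput = criterions[g]:backward(output, targets[g])
      models[g]:backward(inputs[g], gradOutput)
   end
   cutorch.synchronize()                            -- 3. wait for all GPUs

   cutorch.setDevice(1)
   for g = 2, nGPU do                               -- 4. accumulate gradients on GPU 1
      gradBuffer:copy(gradParams[g])                --    (cross-GPU copy, then add locally)
      gradParams[1]:add(gradBuffer)
   end
   params[1]:add(-learningRate, gradParams[1])      -- 5. SGD step on GPU 1

   for g = 2, nGPU do
      params[g]:copy(params[1])                     -- 6. broadcast parameters back
   end
   cutorch.synchronize()                            -- 7. sync before computing accuracy etc.
end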

Also, to train ConvNets using multiple GPUs, I recommend using cuDNN for the convolution layers, as I've tested that it is completely asynchronous (meaning that the processing runs in parallel on multiple GPUs).

Comments below describe the technical details of changes made. If you just want to use Multi-GPU, you can stop reading now.

Member

soumith commented Sep 23, 2014

Comments below describe the technical details of changes made. If you just want to use Multi-GPU, you can stop reading now.
What is missing?

cutorch.setDevice currently resets the random seed as well. This needs to be separated out so that you can use multiple GPUs as needed:
https://github.com/torch/cutorch/blob/master/init.c#L42

GPU-to-GPU copies currently have to go through a host bridge. This needs to be changed to a P2P GPU copy. It is really trivial to implement: in the cutorch initialization function, we just have to enable P2P for each detected GPU with this function:
http://developer.download.nvidia.com/compute/cuda/4_1/rel/toolkit/docs/online/group__CUDART__PEER_g9e5ea65a18938c2b8715a5602105c306.html

After that UVA takes care of everything else, copying tensors from one GPU to another is as simple as this:
cutorch.setDevice(1)
t1 = torch.randn(100):cuda()
cutorch.setDevice(2)
t2 = torch.randn(100):cuda()
-- UVA copy
t2:copy(t1)

Internally, Clement and we already have multi-GPU support, and we will get the changes back into cutorch gradually (it will take time to isolate the commits, get approval, etc.), but if you are really adventurous, this is a couple of hours of work.

Member

nicholas-leonard commented Sep 23, 2014

And I am guessing we should use this for our D2D memory copies: http://developer.download.nvidia.com/compute/cuda/4_1/rel/toolkit/docs/online/group__CUDART__MEMORY_g046702971bc5a66d9bc6000682a6d844.html#g046702971bc5a66d9bc6000682a6d844

This means that if I have a kernel sequence A->B->C (on device 1), followed by a device-to-device memcopy D, followed by kernels E->F->G (on device 2), then eventually A->B->C of iteration t should run in parallel with E->F->G of the previous iteration (t-1), right? Otherwise, I don't see how this can be useful, other than allowing the use of more GPU memory. I mean, ideally you want those GPUs to work on different kernels concurrently. Say A->B->C are the first 3 modules of an nn.Sequential, and E->F->G are the last 3.

Member

soumith commented Sep 23, 2014

@nicholas-leonard you don't need to call cudaMemcpyPeer explicitly anymore; UVA takes care of it.

Member

nicholas-leonard commented Sep 23, 2014

Wow. So you are right, it would be super easy to implement. We check for each device's UVA flag, then call cudaDeviceEnablePeerAccess for every pair of such devices. Easy.

Member

nicholas-leonard commented Sep 23, 2014

Still, would the two sequences in the above example be able to run concurrently?

Member

soumith commented Sep 23, 2014

Yes, they would run concurrently starting from the next iteration, as long as you have no blocking calls anywhere.
Clement already removed all the blocking calls from the convnet pipeline in a few commits earlier this year (I don't remember the exact hashes).

Member

szagoruyko commented Sep 24, 2014

@soumith I did the UVA init for the GPUs and got rid of the "cuda runtime error: an illegal memory access was encountered" errors when copying directly from one tensor to another. I'm not sure, however, that network calls are non-blocking everywhere. How do we test that?

Member

soumith commented Sep 24, 2014

@szagoruyko profile it (with sys.tic()/sys.toc(), or just with os.clock()).
If model:forward() or model:backward() takes more than, say, 5-10 ms, there are blocking calls.
You should only see a big timing number once you call cutorch.synchronize(); until then, none of the CUDA calls should take any noticeable time.
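For example, a minimal sketch of that kind of timing check (assuming model, criterion, input and target already live on the GPU):

require 'sys'

sys.tic()
local output = model:forward(input)
print('forward queued in:  ' .. sys.toc())    -- should be near zero if nothing blocks

sys.tic()
criterion:forward(output, target)
local gradOutput = criterion:backward(output, target)
model:backward(input, gradOutput)
print('backward queued in: ' .. sys.toc())    -- likewise near zero

sys.tic()
cutorch.synchronize()
print('actual GPU time:    ' .. sys.toc())    -- the real work only shows up here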

Member

soumith commented Sep 24, 2014

@szagoruyko SpatialConvolutionMM seems to block randomly sometimes, due to its for-looped cuBLAS calls. Use cuDNN; there is no blocking at all.

Member

clementfarabet commented Sep 24, 2014

Yeah, this is really annoying; I could never figure out why these gemm calls
are blocking sometimes. There's got to be a reason...

Member

soumith commented Sep 24, 2014

@clementfarabet I've even tried moving to cublasv2, and that didn't help. Maybe cuBLAS has a queue that gets filled? It's only conjecture, as we don't have the source code.

Member

clementfarabet commented Sep 24, 2014

It probably does, in which case we would need to use streams.

Clément


Member

szagoruyko commented Sep 24, 2014

@soumith cudnn's forward itself is not blocking, but when I add nn.Reshape and nn.Linear it blocks. backward is not blocking at all, though.

Member

soumith commented Sep 24, 2014

@szagoruyko use nn.View.
nn.Linear shouldn't be blocking, I think. If it still is, let me know and I'll take a look.

Member

nicholas-leonard commented Sep 24, 2014

I think the call to new():fill() blocks here: https://github.com/torch/nn/blob/master/Linear.lua#L48

Member

soumith commented Sep 24, 2014

I don't know anymore what the public cutorch looks like. It's possible that that line is blocking; that line is not needed there, it can be a temporary buffer that is reused.

Member

soumith commented Sep 24, 2014

if not self.addBuffer or (self.addBuffer:size(1) ~= nframe) then
   self.addBuffer = input.new(nframe):fill(1)
end
self.output:zero():addr(1, self.addBuffer, self.bias)

Shall I patch it, or does someone else want to do the honours? Same with lines 89 and 92; the same addBuffer can be reused.

Member

szagoruyko commented Sep 24, 2014

@soumith @nicholas-leonard it doesn't go down that path, actually; nunit is 1 in my case.

Member

soumith commented Sep 24, 2014

Ah, for nunit=1, this line is blocking:
https://github.com/torch/nn/blob/master/Linear.lua#L45
because it copies the bias back to host memory. I'll patch it sometime today.
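As a toy illustration of why host reads block (not the actual Linear.lua code): pulling a single value out of a CudaTensor copies it back to CPU memory, which has to wait for every queued kernel to finish first.

require 'cutorch'
require 'sys'

local t = torch.CudaTensor(4096, 4096):fill(1)

sys.tic()
t:mul(2)                        -- asynchronous: returns almost immediately
print('queued mul:  ' .. sys.toc())

sys.tic()
local v = t[1][1]               -- device-to-host read, so it implicitly synchronizes
print('scalar read: ' .. sys.toc())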

Member

szagoruyko commented Sep 24, 2014

@soumith cool! By the way, MM and ccn2 are blocking, and ccn2's backward is not fully blocking. Should I share the test script somewhere?

Member

nicholas-leonard commented Sep 24, 2014

@soumith thanks

Member

soumith commented Sep 24, 2014

@szagoruyko yes that would be helpful for all

Member

szagoruyko commented Sep 24, 2014

and the pull request is here: #44

Member

szagoruyko commented Sep 24, 2014

By the way, is it possible to have shared modules on different GPUs?

Member

soumith commented Sep 24, 2014

@szagoruyko just added P2P access. Does anyone else want to take on the task of not resetting the random seed every time setDevice is called? All you have to do is move the random seed initialization to CUDA initialization (per device).

Member

soumith commented Sep 24, 2014

@szagoruyko not directly, but if you want to do data-parallel training, your training loop can run like this:

  • loop over GPUs
    • run model[gpuX]:forward + criterion[gpuX]:forward + criterion[gpuX]:backward + model[gpuX]:backward
  • accumulate GPUx's gradParameters to GPU1's gradParameters
  • do SGD on GPU1
  • copy back GPU1's parameters to GPUx
Member

szagoruyko commented Sep 24, 2014

@soumith cool, looks like we can do it efficiently now. Thanks!

Member

dominikgrewe commented Sep 25, 2014

I've created a pull request for moving the random seed initialization: #45

Member

soumith commented Sep 25, 2014

Awesome! Now that this is merged, basic multi-GPU support is essentially done.
Anything on top of this is going to be a feature: helper methods for getting the device associated with a Tensor/Storage, removing any remaining blocking calls across cutorch/cunn, etc.

So, @jonathantompson, to answer your earlier question: Torch has multi-GPU support ;)

@soumith soumith closed this Sep 25, 2014

@szagoruyko szagoruyko referenced this issue in torch/nn Sep 30, 2014

Closed

CUDA blocking call in Linear #77

@lukeyeager lukeyeager referenced this issue in NVIDIA/DIGITS Jun 9, 2015

Closed

Multi-GPU in Torch #138

@zhangzibin zhangzibin referenced this issue in karpathy/char-rnn Jul 7, 2015

Closed

Mutil-GPU support? #57

@sherjilozair sherjilozair referenced this issue in karpathy/char-rnn Aug 7, 2015

Closed

Multi-GPU support #77

algred commented Aug 13, 2015

An error happens when I run the example; it seems the math operations are not "seamless" as claimed:

th> cutorch.setDevice(1)
th> matrix1 = torch.CudaTensor(10):fill(1)
th> print(matrix1)
cutorch.synchronize()
1
1
1
1
1
1
1
1
1
1
[torch.CudaTensor of size 10]
th> cutorch.setDevice(2)
th> matrix2 = torch.CudaTensor(10):fill(2)
th> print(matrix2)
2
2
2
2
2
2
2
2
2
2
[torch.CudaTensor of size 10]
th> matrix2:add(matrix1)
[string "matrix2:add(matrix1) -- matrix1 is seamlessly..."]:1: Assertion `THCudaTensor_checkGPU(state, 3, self_, src1, src2)' failed. at /tmp/luarocks_cutorch-scm-1-6658/cutorch/lib/THC/THCTensorMathPointwise.cu:82
stack traceback:
[C]: in function 'add'
[string "matrix2:add(matrix1) -- matrix1 is seamlessly..."]:1: in main chunk
[C]: in function 'xpcall'
/home/shugao/torch/install/share/lua/5.1/trepl/init.lua:648: in function 'repl'
...ugao/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:185: in main chunk
[C]: at 0x00406670

Member

soumith commented Aug 13, 2015

@algred the only operation that is allowed across GPUs is the copy operation. All other operations are guarded with assertions to make sure you don't do this by accident. To get good performance, you'll have to copy matrix1 onto matrix2's GPU and then do the mathematical operation.

If you don't like this behavior, you can simply disable these assertions by adding the DISABLE_CHECK_GPU define and reinstalling cutorch.

https://github.com/torch/cutorch/blob/master/lib/THC/THCTensor.c#L761
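For illustration, the copy-then-add pattern for the example above would look something like this (matrix1_on_gpu2 is just a name made up here):

cutorch.setDevice(1)
matrix1 = torch.CudaTensor(10):fill(1)

cutorch.setDevice(2)
matrix2 = torch.CudaTensor(10):fill(2)

-- bring matrix1 onto GPU 2 first (copy is the one operation allowed across GPUs)
matrix1_on_gpu2 = torch.CudaTensor(10):copy(matrix1)
matrix2:add(matrix1_on_gpu2)   -- both operands now live on GPU 2
print(matrix2)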

algred commented Aug 13, 2015

@soumith Thank you very much for replying!

darksigma commented Oct 10, 2015

Does this PR support GPUDirect RDMA? Or are additional lower-level modifications necessary to run on a multi-GPU Mellanox/GTX Titan X cluster?


eriche2016 commented Jan 16, 2016

If I do data-parallel training, how is the mini-batch data split across the multiple GPUs? Is it split evenly by the scheduler, or do I split the mini-batch manually?


Member

soumith commented Jan 26, 2016

evenly

Member

soumith commented Jan 26, 2016

@darksigma try looking at nccl.torch for that.

byronwwang commented May 19, 2016

@soumith For mini-batch splitting, is there any shuffling beforehand? Or is the batch just split evenly according to the original order of the samples?


Member

soumith commented Jun 4, 2016

No shuffle, split evenly
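For the curious, an even, in-order split can be written by hand along these lines (a hedged sketch only; splitBatch and its arguments are made up for this example, and this is the kind of thing a data-parallel container does for you internally):

-- split a CPU mini-batch of size B evenly, in order, across nGPU devices
local function splitBatch(batchCPU, nGPU)
   local B = batchCPU:size(1)
   local perGPU = math.floor(B / nGPU)
   local chunks = {}
   for g = 1, nGPU do
      cutorch.setDevice(g)
      local first = (g - 1) * perGPU + 1
      local size = (g < nGPU) and perGPU or (B - first + 1)  -- last GPU takes any remainder
      chunks[g] = batchCPU:narrow(1, first, size):cuda()     -- each chunk lands on its own GPU
   end
   return chunks
end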

@VincentSC VincentSC referenced this issue in torch/cunn Apr 23, 2018

Open

use memory of GPU for a process #466
