Commit

DPT Synchronize() performance fix. GPU flag fix
lukacf committed Feb 10, 2016
1 parent c73fb73 commit edcbd57
Showing 2 changed files with 6 additions and 2 deletions.
5 changes: 4 additions & 1 deletion train.lua
@@ -163,7 +163,10 @@ function trainBatch(inputsCPU, labelsCPU)
    optim.sgd(feval, parameters, optimState)

    -- DataParallelTable's syncParameters
-   model:apply(function(m) if m.syncParameters then m:syncParameters() end end)
+   if model.needsSync then
+      model:syncParameters()
+   end
+

    cutorch.synchronize()
    batchNumber = batchNumber + 1
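
The train.lua hunk above swaps a walk over every submodule for a single check on the top-level model. The following is a minimal sketch in plain Lua of why that does less work per batch. It uses a hypothetical mock container (`newModule`, `syncCalls`, and `visited` are illustration-only names, not part of the nn or cunn API), and it assumes that `apply` calls the callback on the module itself and then recurses into its children.

```lua
-- Hypothetical stand-in for a container module; not the real nn API.
local function newModule(children)
   local m = { modules = children or {} }
   -- Assumed behaviour of apply: call the callback on self, then recurse.
   function m:apply(callback)
      callback(self)
      for _, child in ipairs(self.modules) do
         child:apply(callback)
      end
   end
   return m
end

-- Top-level mock with three plain children; only the top level can sync.
local model = newModule({ newModule(), newModule(), newModule() })
model.needsSync = true
model.syncCalls = 0
function model:syncParameters()
   self.syncCalls = self.syncCalls + 1
   self.needsSync = false
end

-- Old approach: visit every module and test each one for syncParameters.
local visited = 0
model:apply(function(m)
   visited = visited + 1
   if m.syncParameters then m:syncParameters() end
end)

-- New approach: one field check on the top-level model, no traversal.
model.needsSync = true
if model.needsSync then
   model:syncParameters()
end

print(visited, model.syncCalls)
```

In this sketch the traversal touches every module on each call, while the flag check costs one field lookup. A side effect of the new form is that a model without a `needsSync` field (for example, one that was never wrapped for multiple GPUs) skips the call entirely.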
3 changes: 2 additions & 1 deletion util.lua
@@ -10,8 +10,9 @@ function makeDataParallel(model, nGPU)
          cutorch.setDevice(i)
          model:add(model_single:clone():cuda(), i)
       end
-      cutorch.setDevice(opt.GPU)
    end
+   cutorch.setDevice(opt.GPU)
+
    return model
 end
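
The util.lua hunk moves `cutorch.setDevice(opt.GPU)` out of the multi-GPU branch, so the requested device is selected even when only one GPU is used. The following is a rough sketch of that control flow, with a mocked `cutorch` table and a hypothetical `opt` value so it runs without Torch. `makeDataParallelSketch`, `currentDevice`, and the device numbers are assumptions for illustration, and the `nGPU > 1` guard is inferred because the diff does not show the branch condition.

```lua
-- Mock of the one cutorch call used here; records the active device.
local currentDevice = nil
local cutorch = { setDevice = function(i) currentDevice = i end }
local opt = { GPU = 2 }  -- hypothetical option table

local function makeDataParallelSketch(model, nGPU)
   if nGPU > 1 then  -- assumed guard; not visible in the hunk
      for i = 1, nGPU do
         cutorch.setDevice(i)
         -- the real code adds a clone of the model on device i here
      end
   end
   -- After the fix: runs whether or not the branch above was taken.
   cutorch.setDevice(opt.GPU)
   return model
end

makeDataParallelSketch({}, 1)
local afterSingle = currentDevice
makeDataParallelSketch({}, 4)
local afterMulti = currentDevice
print(afterSingle, afterMulti)
```

With the call inside the branch, the single-GPU path in this sketch would leave `currentDevice` unset; with it outside, both paths end on `opt.GPU`.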
