Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
52 lines (43 sloc) 1.3 KB
--- Helper routines for Torch `nn` models: convolution/linear initialisation,
-- bias removal, a forward/backward smoke test and multi-GPU wrapping.
local utils = {}
--- Re-initialise every `nn.SpatialConvolution` found in `model`.
-- Each layer's weight is redrawn with `weight:normal(0, sqrt(2 / n))`, where
-- `n = kW * kH * nInputPlane` for that layer. The bias is zeroed when the
-- layer has one; bias-free layers are left alone.
-- @param model object exposing `findModules` (presumably a torch `nn` module)
function utils.MSRinit(model)
   local convs = model:findModules('nn.SpatialConvolution')
   for _, conv in pairs(convs) do
      -- Number of inputs feeding one output unit of this convolution.
      local fanIn = conv.kW * conv.kH * conv.nInputPlane
      local std = math.sqrt(2 / fanIn)
      conv.weight:normal(0, std)
      if conv.bias then
         conv.bias:zero()
      end
   end
end
--- Zero the bias of every `nn.Linear` found in `model`.
-- Layers whose `bias` field is nil (a Linear constructed with its bias
-- disabled) are skipped instead of raising "attempt to index a nil value".
-- Weights are not touched.
-- @param model object exposing `findModules` (presumably a torch `nn` module)
function utils.FCinit(model)
   for _, layer in pairs(model:findModules('nn.Linear')) do
      if layer.bias then
         layer.bias:zero()
      end
   end
end
--- Strip the bias term from every `nn.SpatialConvolution` in `model`.
-- Both the `bias` and `gradBias` fields are set to nil on each module found;
-- no other field is modified.
-- @param model object exposing `findModules` (presumably a torch `nn` module)
function utils.DisableBias(model)
   local convs = model:findModules('nn.SpatialConvolution')
   for _, conv in ipairs(convs) do
      conv.bias = nil
      conv.gradBias = nil
   end
end
--- Smoke-test `model` with one forward and one backward pass on random input.
-- The model is first converted with `model:float()`. The input is a
-- `torch.randn(1, 3, S, S)` tensor cast to `model._type`, where S is the global
-- `opt.imageSize` when that is set and 32 otherwise. `model.output` is fed back
-- as the gradient for the backward pass, both results are printed, and finally
-- `model:reset()` is called.
-- NOTE(review): `reset()` presumably re-initialises the parameters, so this
-- looks intended to run before weights are initialised or loaded — confirm.
-- @param model network to exercise; mutated by `float()` and `reset()`
function utils.testModel(model)
   model:float()

   -- Fall back to 32 when no global `opt` (or no `opt.imageSize`) is present.
   local side = 32
   if opt and opt.imageSize then
      side = opt.imageSize
   end

   local input = torch.randn(1, 3, side, side):type(model._type)

   -- Wrap in a table so every returned value is captured for printing.
   local forwardResults = { model:forward(input) }
   print('forward output', forwardResults)

   local backwardResults = { model:backward(input, model.output) }
   print('backward output', backwardResults)

   model:reset()
end
--- Wrap `model` in an `nn.DataParallelTable` when more than one GPU is requested.
-- For `nGPU > 1` the model is added for GPU ids 1..nGPU, the function passed
-- to `:threads` re-applies the caller's current `cudnn.fastest` and
-- `cudnn.benchmark` values, the wrapper's `gradInput` is cleared, and the
-- result of `:cuda()` is returned. Otherwise `model` is returned untouched.
-- @param model the network to (optionally) parallelise
-- @param nGPU number of GPUs to spread the model over
-- @return the `:cuda()` result of the wrapper, or `model` itself when `nGPU`
--   is not above 1
function utils.makeDataParallelTable(model, nGPU)
   if not (nGPU > 1) then
      return model
   end

   local deviceIds = torch.range(1, nGPU):totable()

   -- Snapshot the flags as plain locals so the initialiser below can carry
   -- them as upvalues.
   local useFastest = cudnn.fastest
   local useBenchmark = cudnn.benchmark

   local parallel = nn.DataParallelTable(1, true, true)
   parallel = parallel:add(model, deviceIds)
   parallel = parallel:threads(function()
      -- Re-require inside the callback (presumably it runs in each worker
      -- thread, which has its own cudnn module state — TODO confirm).
      local cudnn = require 'cudnn'
      cudnn.fastest = useFastest
      cudnn.benchmark = useBenchmark
   end)
   parallel.gradInput = nil

   return parallel:cuda()
end
-- Hand the helper table back as this file's return value.
return utils