Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko committed Jun 9, 2018
1 parent 3c6b094 commit 90a21b5
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions diracnet.py
Expand Up @@ -32,7 +32,7 @@ def bnparams(n):


def data_parallel(f, input, params, mode, device_ids, output_device=None):
assert isinstance(device_ids, list)
device_ids = list(device_ids)
if output_device is None:
output_device = device_ids[0]

Expand Down Expand Up @@ -99,8 +99,8 @@ def define_diracnet(depth, width, dataset):

def gen_group_params(ni, no, count):
return {'block%d' % i: {'conv': conv_params(ni if i == 0 else no, no, k=3),
'alpha': cast(torch.ones(no).fill_(1)),
'beta': cast(torch.ones(no).fill_(0.1)),
'alpha': torch.ones(no).fill_(1),
'beta': torch.ones(no).fill_(0.1),
'bn': bnparams(no)} for i in range(count)}

if dataset.startswith('CIFAR'):
Expand All @@ -120,7 +120,7 @@ def f(inputs, params, mode):
return o

params = {
'conv': cast(kaiming_normal_(torch.Tensor(widths[0], 3, 3, 3))),
'conv': kaiming_normal_(torch.Tensor(widths[0], 3, 3, 3)),
'bn': bnparams(widths[0]),
'group0': gen_group_params(widths[0], widths[0], n * 2),
'group1': gen_group_params(widths[0], widths[1], n * 2),
Expand Down Expand Up @@ -150,7 +150,7 @@ def f(inputs, params, mode):
return o

params = {
'conv': cast(kaiming_normal_(torch.Tensor(widths[0], 3, 7, 7))),
'conv': kaiming_normal_(torch.Tensor(widths[0], 3, 7, 7)),
'group0': gen_group_params(widths[0], widths[0], 2 * blocks[0]),
'group1': gen_group_params(widths[0], widths[1], 2 * blocks[1]),
'group2': gen_group_params(widths[1], widths[2], 2 * blocks[2]),
Expand All @@ -163,11 +163,13 @@ def f(inputs, params, mode):

flat_params = flatten(params)

flat_params = {k: cast(v.data) for k, v in flat_params.items()}

set_requires_grad_except_bn_(flat_params)

for k, v in list(flat_params.items()):
if k.find('.conv') > -1:
flat_params[size2name(v.size())] = cast(dirac_(v.data.clone()))
flat_params[size2name(v.size())] = dirac_(v.data.clone())

return f, flat_params

Expand Down

0 comments on commit 90a21b5

Please sign in to comment.