From 90a21b5d6ad409709fcf0eefaaebfc1ecd9fe1dd Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Sat, 9 Jun 2018 19:45:31 +0200 Subject: [PATCH] fixes --- diracnet.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/diracnet.py b/diracnet.py index fdb669e..5a32e5a 100644 --- a/diracnet.py +++ b/diracnet.py @@ -32,7 +32,7 @@ def bnparams(n): def data_parallel(f, input, params, mode, device_ids, output_device=None): - assert isinstance(device_ids, list) + device_ids = list(device_ids) if output_device is None: output_device = device_ids[0] @@ -99,8 +99,8 @@ def define_diracnet(depth, width, dataset): def gen_group_params(ni, no, count): return {'block%d' % i: {'conv': conv_params(ni if i == 0 else no, no, k=3), - 'alpha': cast(torch.ones(no).fill_(1)), - 'beta': cast(torch.ones(no).fill_(0.1)), + 'alpha': torch.ones(no).fill_(1), + 'beta': torch.ones(no).fill_(0.1), 'bn': bnparams(no)} for i in range(count)} if dataset.startswith('CIFAR'): @@ -120,7 +120,7 @@ def f(inputs, params, mode): return o params = { - 'conv': cast(kaiming_normal_(torch.Tensor(widths[0], 3, 3, 3))), + 'conv': kaiming_normal_(torch.Tensor(widths[0], 3, 3, 3)), 'bn': bnparams(widths[0]), 'group0': gen_group_params(widths[0], widths[0], n * 2), 'group1': gen_group_params(widths[0], widths[1], n * 2), @@ -150,7 +150,7 @@ def f(inputs, params, mode): return o params = { - 'conv': cast(kaiming_normal_(torch.Tensor(widths[0], 3, 7, 7))), + 'conv': kaiming_normal_(torch.Tensor(widths[0], 3, 7, 7)), 'group0': gen_group_params(widths[0], widths[0], 2 * blocks[0]), 'group1': gen_group_params(widths[0], widths[1], 2 * blocks[1]), 'group2': gen_group_params(widths[1], widths[2], 2 * blocks[2]), @@ -163,11 +163,13 @@ def f(inputs, params, mode): flat_params = flatten(params) + flat_params = {k: cast(v.data) for k, v in flat_params.items()} + set_requires_grad_except_bn_(flat_params) for k, v in list(flat_params.items()): if k.find('.conv') > -1: - flat_params[size2name(v.size())] = cast(dirac_(v.data.clone())) + flat_params[size2name(v.size())] = dirac_(v.data.clone()) return f, flat_params