Skip to content

Commit

Permalink
Updated with weight initialization warnings (#2170)
Browse files Browse the repository at this point in the history
  • Loading branch information
bisakhmondal committed May 5, 2020
1 parent 6f849df commit e1a3042
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion torchvision/models/googlenet.py
Expand Up @@ -62,11 +62,16 @@ def googlenet(pretrained=False, progress=True, **kwargs):
class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True,
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=None,
blocks=None):
super(GoogLeNet, self).__init__()
if blocks is None:
blocks = [BasicConv2d, Inception, InceptionAux]
if init_weights is None:
warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of '
'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
init_weights = True
assert len(blocks) == 3
conv_block = blocks[0]
inception_block = blocks[1]
Expand Down
7 changes: 6 additions & 1 deletion torchvision/models/inception.py
Expand Up @@ -63,13 +63,18 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
class Inception3(nn.Module):

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
inception_blocks=None, init_weights=True):
inception_blocks=None, init_weights=None):
super(Inception3, self).__init__()
if inception_blocks is None:
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux
]
if init_weights is None:
warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
init_weights = True
assert len(inception_blocks) == 7
conv_block = inception_blocks[0]
inception_a = inception_blocks[1]
Expand Down

0 comments on commit e1a3042

Please sign in to comment.