From 9f1b6403641632dac8248399bfabca65ca655dca Mon Sep 17 00:00:00 2001 From: Konstantin Lopuhin Date: Fri, 14 Jul 2017 16:57:27 +0300 Subject: [PATCH] pass kwargs to densenet --- torchvision/models/densenet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 0b32d80ea03..e635132f285 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -22,7 +22,8 @@ def densenet121(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)) + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), + **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['densenet121'])) return model @@ -35,7 +36,8 @@ def densenet169(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32)) + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), + **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['densenet169'])) return model @@ -48,7 +50,8 @@ def densenet201(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32)) + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), + **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['densenet201'])) return model @@ -61,7 +64,8 @@ def densenet161(pretrained=False, **kwargs): Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24)) + model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), + **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['densenet161'])) return model