From e69441117733494f4ec7dac1a1f438bc16cd5700 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 8 May 2022 16:38:34 +0100 Subject: [PATCH] CleanUp DenseNet code --- torchvision/models/densenet.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 12a0e645545..2fb60fc68cc 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -34,22 +34,14 @@ def __init__( self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False ) -> None: super().__init__() - self.norm1: nn.BatchNorm2d - self.add_module("norm1", nn.BatchNorm2d(num_input_features)) - self.relu1: nn.ReLU - self.add_module("relu1", nn.ReLU(inplace=True)) - self.conv1: nn.Conv2d - self.add_module( - "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False) - ) - self.norm2: nn.BatchNorm2d - self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate)) - self.relu2: nn.ReLU - self.add_module("relu2", nn.ReLU(inplace=True)) - self.conv2: nn.Conv2d - self.add_module( - "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False) - ) + self.norm1 = nn.BatchNorm2d(num_input_features) + self.relu1 = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False) + + self.norm2 = nn.BatchNorm2d(bn_size * growth_rate) + self.relu2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False) + self.drop_rate = float(drop_rate) self.memory_efficient = memory_efficient @@ -136,10 +128,10 @@ def forward(self, init_features: Tensor) -> Tensor: class _Transition(nn.Sequential): def __init__(self, num_input_features: int, num_output_features: int) -> None: super().__init__() - self.add_module("norm", nn.BatchNorm2d(num_input_features)) - self.add_module("relu", nn.ReLU(inplace=True)) - self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) - self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2)) + self.norm = nn.BatchNorm2d(num_input_features) + self.relu = nn.ReLU(inplace=True) + self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False) + self.pool = nn.AvgPool2d(kernel_size=2, stride=2) class DenseNet(nn.Module):