diff --git a/references/video_classification/train.py b/references/video_classification/train.py index f532f121e70..192babf62dc 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -201,8 +201,7 @@ def main(args): pin_memory=True, collate_fn=collate_fn) print("Creating model") - # model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) - model = torchvision.models.video.__dict__[args.model]() + model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/test/test_models.py b/test/test_models.py index 00ea4e65f93..443e9dd236e 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -61,7 +61,7 @@ def _test_detection_model(self, name): def _test_video_model(self, name): # the default input shape is # bs * num_channels * clip_len * h *w - input_shape = (1, 3, 8, 112, 112) + input_shape = (1, 3, 4, 112, 112) # test both basicblock and Bottleneck model = models.video.__dict__[name](num_classes=50) x = torch.rand(input_shape) @@ -145,6 +145,7 @@ def do_test(self, model_name=model_name): setattr(Tester, "test_" + model_name, do_test) + for model_name in get_available_video_models(): def do_test(self, model_name=model_name): diff --git a/torchvision/models/video/__init__.py b/torchvision/models/video/__init__.py index e6e663c50c0..b792ca6ecf7 100644 --- a/torchvision/models/video/__init__.py +++ b/torchvision/models/video/__init__.py @@ -1,3 +1 @@ -from .r3d import * -from .r2plus1d import * -from .mixed_conv import * +from .resnet import * diff --git a/torchvision/models/video/_utils.py b/torchvision/models/video/_utils.py deleted file mode 100644 index 03feb83b8c4..00000000000 --- a/torchvision/models/video/_utils.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch.nn as nn - - -__all__ = ["Conv3DSimple", "Conv2Plus1D", "Conv3DNoTemporal"] - - 
-class Conv3DSimple(nn.Conv3d): - def __init__(self, - in_planes, - out_planes, - midplanes=None, - stride=1, - padding=1): - - super(Conv3DSimple, self).__init__( - in_channels=in_planes, - out_channels=out_planes, - kernel_size=(3, 3, 3), - stride=stride, - padding=padding, - bias=False) - - @staticmethod - def get_downsample_stride(stride): - return (stride, stride, stride) - - -class Conv2Plus1D(nn.Sequential): - - def __init__(self, - in_planes, - out_planes, - midplanes, - stride=1, - padding=1): - conv1 = [ - nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), - stride=(1, stride, stride), padding=(0, padding, padding), - bias=False), - nn.BatchNorm3d(midplanes), - nn.ReLU(inplace=True), - nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), - stride=(stride, 1, 1), padding=(padding, 0, 0), - bias=False) - ] - super(Conv2Plus1D, self).__init__(*conv1) - - @staticmethod - def get_downsample_stride(stride): - return (stride, stride, stride) - - -class Conv3DNoTemporal(nn.Conv3d): - - def __init__(self, - in_planes, - out_planes, - midplanes=None, - stride=1, - padding=1): - - super(Conv3DNoTemporal, self).__init__( - in_channels=in_planes, - out_channels=out_planes, - kernel_size=(1, 3, 3), - stride=(1, stride, stride), - padding=(0, padding, padding), - bias=False) - - @staticmethod - def get_downsample_stride(stride): - return (1, stride, stride) diff --git a/torchvision/models/video/mixed_conv.py b/torchvision/models/video/mixed_conv.py deleted file mode 100644 index 4b8895ab891..00000000000 --- a/torchvision/models/video/mixed_conv.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch.nn as nn - -from ._utils import Conv3DSimple, Conv3DNoTemporal -from .video_stems import get_default_stem -from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck - - -__all__ = ["mc3_18"] - - -def _mcX(model_depth, X=3, use_pool1=False, **kwargs): - """Generate mixed convolution network as in - https://arxiv.org/abs/1711.11248 - - Args: - model_depth (int): trunk 
depth - supports most resnet depths - X (int): Up to which layers are convolutions 3D - use_pool1 (bool, optional): Add pooling layer to the stem. Defaults to False. - - Returns: - nn.Module: mcX video trunk - """ - assert X > 1 and X <= 5 - conv_makers = [Conv3DSimple] * (X - 2) - while len(conv_makers) < 5: - conv_makers.append(Conv3DNoTemporal) - - if model_depth < 50: - block = BasicBlock - else: - block = Bottleneck - - model = VideoTrunkBuilder(block=block, conv_makers=conv_makers, model_depth=model_depth, - stem=get_default_stem(use_pool1=use_pool1), **kwargs) - - return model - - -def mc3_18(use_pool1=False, **kwargs): - """Constructor for 18 layer Mixed Convolution network as in - https://arxiv.org/abs/1711.11248 - - Args: - use_pool1 (bool, optional): Include pooling in the resnet stem. Defaults to False. - - Returns: - nn.Module: MC3 Network definitino - """ - return _mcX(18, 3, use_pool1, **kwargs) diff --git a/torchvision/models/video/r2plus1d.py b/torchvision/models/video/r2plus1d.py deleted file mode 100644 index 0ab13419570..00000000000 --- a/torchvision/models/video/r2plus1d.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch.nn as nn - -from ._utils import Conv2Plus1D -from .video_stems import get_r2plus1d_stem -from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck - - -__all__ = ["r2plus1d_18"] - - -def _r2plus1d(model_depth, use_pool1=False, **kwargs): - """Constructor for R(2+1)D network as described in - https://arxiv.org/abs/1711.11248 - - Args: - model_depth (int): Depth of the model - standard resnet depths apply - use_pool1 (bool, optional): Should we use the pooling layer? 
Defaults to False - Returns: - nn.Module: An R(2+1)D video backbone - """ - convs = [Conv2Plus1D] * 4 - if model_depth < 50: - block = BasicBlock - else: - block = Bottleneck - - model = VideoTrunkBuilder( - block=block, conv_makers=convs, model_depth=model_depth, - stem=get_r2plus1d_stem(use_pool1), **kwargs) - return model - - -def r2plus1d_18(use_pool1=False, **kwargs): - """Constructor for the 18 layer deep R(2+1)D network as in - https://arxiv.org/abs/1711.11248 - - Args: - use_pool1 (bool, optional): Include pooling in the resnet stem. Defaults to False. - - Returns: - nn.Module: R(2+1)D-18 network - """ - return _r2plus1d(18, use_pool1, **kwargs) diff --git a/torchvision/models/video/r3d.py b/torchvision/models/video/r3d.py deleted file mode 100644 index 08fe4f3b219..00000000000 --- a/torchvision/models/video/r3d.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch.nn as nn - -from ._utils import Conv3DSimple -from .video_stems import get_default_stem -from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck - -__all__ = ["r3d_18"] - - -def _r3d(model_depth, use_pool1=False, **kwargs): - """Constructor of a r3d network as in - https://arxiv.org/abs/1711.11248 - - Args: - model_depth (int): resnet trunk depth - use_pool1 (bool, optional): Add pooling layer to the stem. Defaults to False - - Returns: - nn.Module: R3D network trunk - """ - - conv_makers = [Conv3DSimple] * 4 - if model_depth < 50: - block = BasicBlock - else: - block = Bottleneck - - model = VideoTrunkBuilder(block=block, conv_makers=conv_makers, model_depth=model_depth, - stem=get_default_stem(use_pool1=use_pool1), **kwargs) - return model - - -def r3d_18(use_pool1=False, **kwargs): - """Construct 18 layer Resnet3D model as in - https://arxiv.org/abs/1711.11248 - - Args: - use_pool1 (bool, optional): Include pooling in resnet stem. Defaults to False. 
- - Returns: - nn.Module: R3D-18 network - """ - return _r3d(18, use_pool1, **kwargs) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py new file mode 100644 index 00000000000..8fe3bd1007f --- /dev/null +++ b/torchvision/models/video/resnet.py @@ -0,0 +1,340 @@ +import torch +import torch.nn as nn + +from ..utils import load_state_dict_from_url + + +__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/', + 'resnet34': 'https://download.pytorch.org/models/', +} + + +class Conv3DSimple(nn.Conv3d): + def __init__(self, + in_planes, + out_planes, + midplanes=None, + stride=1, + padding=1): + + super(Conv3DSimple, self).__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(3, 3, 3), + stride=stride, + padding=padding, + bias=False) + + @staticmethod + def get_downsample_stride(stride): + return (stride, stride, stride) + + +class Conv2Plus1D(nn.Sequential): + + def __init__(self, + in_planes, + out_planes, + midplanes, + stride=1, + padding=1): + super(Conv2Plus1D, self).__init__( + nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), + stride=(1, stride, stride), padding=(0, padding, padding), + bias=False), + nn.BatchNorm3d(midplanes), + nn.ReLU(inplace=True), + nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), + stride=(stride, 1, 1), padding=(padding, 0, 0), + bias=False)) + + @staticmethod + def get_downsample_stride(stride): + return (stride, stride, stride) + + +class Conv3DNoTemporal(nn.Conv3d): + + def __init__(self, + in_planes, + out_planes, + midplanes=None, + stride=1, + padding=1): + + super(Conv3DNoTemporal, self).__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False) + + @staticmethod + def get_downsample_stride(stride): + return (1, stride, stride) + + +class BasicBlock(nn.Module): + + expansion = 1 + + def __init__(self, 
inplanes, planes, conv_builder, stride=1, downsample=None): + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + super(BasicBlock, self).__init__() + self.conv1 = nn.Sequential( + conv_builder(inplanes, planes, midplanes, stride), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes), + nn.BatchNorm3d(planes) + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): + + super(Bottleneck, self).__init__() + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + # 1x1x1 + self.conv1 = nn.Sequential( + nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + # Second kernel + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes, stride), + nn.BatchNorm3d(planes), + nn.ReLU(inplace=True) + ) + + # 1x1x1 + self.conv3 = nn.Sequential( + nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), + nn.BatchNorm3d(planes * self.expansion) + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BasicStem(nn.Sequential): + """The default conv-batchnorm-relu stem + """ + def __init__(self): + super(BasicStem, self).__init__( + nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), + padding=(1, 3, 3), 
bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True)) + + +class R2Plus1dStem(nn.Sequential): + """R(2+1)D stem is different than the default one as it uses separated 3D convolution + """ + def __init__(self): + super(R2Plus1dStem, self).__init__( + nn.Conv3d(3, 45, kernel_size=(1, 7, 7), + stride=(1, 2, 2), padding=(0, 3, 3), + bias=False), + nn.BatchNorm3d(45), + nn.ReLU(inplace=True), + nn.Conv3d(45, 64, kernel_size=(3, 1, 1), + stride=(1, 1, 1), padding=(1, 0, 0), + bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True)) + + +class VideoResNet(nn.Module): + + def __init__(self, block, conv_makers, layers, + stem, num_classes=400, + zero_init_residual=False): + """Generic resnet video generator. + + Args: + block (nn.Module): resnet building block + conv_makers (list(functions)): generator function for each layer + layers (List[int]): number of blocks per layer + stem (callable): module class or factory, called with no arguments to build the resnet stem (e.g. BasicStem). + num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. + zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 
+ """ + super(VideoResNet, self).__init__() + self.inplanes = 64 + + self.stem = stem() + + self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) + self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) + + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + # init weights + self._initialize_weights() + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.conv3[1].weight, 0)  # last BN of the block is conv3[1]; Bottleneck has no bn3 + + def forward(self, x): + x = self.stem(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + # Flatten the layer to fc + x = x.flatten(1) + x = self.fc(x) + + return x + + def _make_layer(self, block, conv_builder, planes, blocks, stride=1): + downsample = None + + if stride != 1 or self.inplanes != planes * block.expansion: + ds_stride = conv_builder.get_downsample_stride(stride) + downsample = nn.Sequential( + nn.Conv3d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion) + ) + layers = [] + layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, conv_builder)) + + return nn.Sequential(*layers) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', + nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + 
nn.init.constant_(m.bias, 0) + + +def _video_resnet(arch, pretrained=False, progress=True, **kwargs): + model = VideoResNet(**kwargs) + + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def r3d_18(pretrained=False, progress=True, **kwargs): + """Construct 18 layer Resnet3D model as in + https://arxiv.org/abs/1711.11248 + + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + + Returns: + nn.Module: R3D-18 network + """ + + return _video_resnet('r3d_18', + pretrained, progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] * 4, + layers=[2, 2, 2, 2], + stem=BasicStem, **kwargs) + + +def mc3_18(pretrained=False, progress=True, **kwargs): + """Constructor for 18 layer Mixed Convolution network as in + https://arxiv.org/abs/1711.11248 + + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + + Returns: + nn.Module: MC3 Network definition + """ + return _video_resnet('mc3_18', + pretrained, progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, + layers=[2, 2, 2, 2], + stem=BasicStem, **kwargs) + + +def r2plus1d_18(pretrained=False, progress=True, **kwargs): + """Constructor for the 18 layer deep R(2+1)D network as in + https://arxiv.org/abs/1711.11248 + + Args: + pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + progress (bool): If True, displays a progress bar of the download to stderr + + Returns: + nn.Module: R(2+1)D-18 network + """ + return _video_resnet('r2plus1d_18', + pretrained, progress, + block=BasicBlock, + conv_makers=[Conv2Plus1D] * 4, + layers=[2, 2, 2, 2], + stem=R2Plus1dStem, **kwargs) diff --git a/torchvision/models/video/video_stems.py b/torchvision/models/video/video_stems.py deleted 
file mode 100644 index 4813aa84390..00000000000 --- a/torchvision/models/video/video_stems.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch.nn as nn - - -def get_default_stem(use_pool1=False): - """The default conv-batchnorm-relu(-maxpool) stem - - Args: - use_pool1 (bool, optional): Should the stem include the default maxpool? Defaults to False. - - Returns: - nn.Sequential: Conv1 stem of resnet based models. - """ - - m = [ - nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), - padding=(1, 3, 3), bias=False), - nn.BatchNorm3d(64), - nn.ReLU(inplace=True)] - if use_pool1: - m.append(nn. MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)) - return nn.Sequential(*m) - - -def get_r2plus1d_stem(use_pool1=False): - """R(2+1)D stem is different than the default one as it uses separated 3D convolution - - Args: - use_pool1 (bool, optional): Should the stem contain pool1 layer. Defaults to False. - - Returns: - nn.Sequential: the stem of the conv-separated network. - """ - - m = [ - nn.Conv3d(3, 45, kernel_size=(1, 7, 7), - stride=(1, 2, 2), padding=(0, 3, 3), - bias=False), - nn.BatchNorm3d(45), - nn.ReLU(inplace=True), - nn.Conv3d(45, 64, kernel_size=(3, 1, 1), - stride=(1, 1, 1), padding=(1, 0, 0), - bias=False), - nn.BatchNorm3d(64), - nn.ReLU(inplace=True)] - - if use_pool1: - m.append(nn. 
MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)) - return nn.Sequential(*m) diff --git a/torchvision/models/video/video_trunk.py b/torchvision/models/video/video_trunk.py deleted file mode 100644 index e28d60c611b..00000000000 --- a/torchvision/models/video/video_trunk.py +++ /dev/null @@ -1,189 +0,0 @@ -import inspect -import torch -import torch.nn as nn - -from .video_stems import get_default_stem -from ._utils import Conv3DNoTemporal - - -BLOCK_CONFIG = { - 10: (1, 1, 1, 1), - 16: (2, 2, 2, 1), - 18: (2, 2, 2, 2), - 26: (2, 3, 4, 3), - 34: (3, 4, 6, 3), - 50: (3, 4, 6, 3), - 101: (3, 4, 23, 3), - 152: (3, 8, 36, 3) -} - - -class BasicBlock(nn.Module): - - expansion = 1 - - def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): - midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) - - super(BasicBlock, self).__init__() - self.conv1 = nn.Sequential( - conv_builder(inplanes, planes, midplanes, stride), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) - ) - self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes), - nn.BatchNorm3d(planes) - ) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.conv2(out) - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): - - super(Bottleneck, self).__init__() - midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) - - # 1x1x1 - self.conv1 = nn.Sequential( - nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) - ) - # Second kernel - self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes, stride), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) - ) - - # 
1x1x1 - self.conv3 = nn.Sequential( - nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), - nn.BatchNorm3d(planes * self.expansion) - ) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.conv2(out) - out = self.conv3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class VideoTrunkBuilder(nn.Module): - - def __init__(self, block, conv_makers, model_depth, - stem=None, - num_classes=400, - zero_init_residual=False): - """Generic resnet video generator. - - Args: - block (nn.Module): resnet building block - conv_makers (list(functions)): generator function for each layer - model_depth (int): depth of the model; supports traditional resnet depths . - stem (nn.Sequential, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. - num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. - zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. 
- """ - super(VideoTrunkBuilder, self).__init__() - layers = BLOCK_CONFIG[model_depth] - self.inplanes = 64 - - if stem is None: - self.conv1 = get_default_stem() - else: - self.conv1 = stem - - self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) - self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) - self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) - self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) - - self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) - - # init weights - self._initialize_weights() - - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) - - def forward(self, x): - x = self.conv1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - # Flatten the layer to fc - x = x.flatten(1) - x = self.fc(x) - - return x - - def _make_layer(self, block, conv_builder, planes, blocks, stride=1): - downsample = None - - if stride != 1 or self.inplanes != planes * block.expansion: - ds_stride = conv_builder.get_downsample_stride(stride) - downsample = nn.Sequential( - nn.Conv3d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=ds_stride, bias=False), - nn.BatchNorm3d(planes * block.expansion) - ) - layers = [] - layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) - - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, conv_builder)) - - return nn.Sequential(*layers) - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv3d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', - nonlinearity='relu') - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm3d): - nn.init.constant_(m.weight, 1) - 
nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - nn.init.constant_(m.bias, 0)