Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,12 @@ def __init__(
self.fc = nn.Linear(1024, num_classes)

if init_weights:
self._initialize_weights()

def _initialize_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def _transform_input(self, x: Tensor) -> Tensor:
if self.transform_input:
Expand Down
14 changes: 6 additions & 8 deletions torchvision/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,7 @@ def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2)
]
self.layers = nn.Sequential(*layers)
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))
self._initialize_weights()

def forward(self, x: Tensor) -> Tensor:
x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3])
return self.classifier(x)

def _initialize_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
Expand All @@ -149,6 +141,12 @@ def _initialize_weights(self) -> None:
nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
nn.init.zeros_(m.bias)

def forward(self, x: Tensor) -> Tensor:
x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3])
return self.classifier(x)

def _load_from_state_dict(
self,
state_dict: Dict,
Expand Down
3 changes: 0 additions & 3 deletions torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@ def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_l

self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1)

self._init_weights()

def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
Expand Down
24 changes: 10 additions & 14 deletions torchvision/models/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,20 +366,6 @@ def __init__(
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(in_features=current_width, out_features=num_classes)

# Init weights and good to go
self._reset_parameters()

def forward(self, x: Tensor) -> Tensor:
x = self.stem(x)
x = self.trunk_output(x)

x = self.avgpool(x)
x = x.flatten(start_dim=1)
x = self.fc(x)

return x

def _reset_parameters(self) -> None:
# Performs ResNet-style weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
Expand All @@ -393,6 +379,16 @@ def _reset_parameters(self) -> None:
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.zeros_(m.bias)

def forward(self, x: Tensor) -> Tensor:
x = self.stem(x)
x = self.trunk_output(x)

x = self.avgpool(x)
x = x.flatten(start_dim=1)
x = self.fc(x)

return x


def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet:
norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
Expand Down
25 changes: 11 additions & 14 deletions torchvision/models/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@ def __init__(
nn.Linear(4096, num_classes),
)
if init_weights:
self._initialize_weights()
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
Expand All @@ -59,19 +69,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.classifier(x)
return x

def _initialize_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)


def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
layers: List[nn.Module] = []
Expand Down
25 changes: 11 additions & 14 deletions torchvision/models/video/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,17 @@ def __init__(
self.fc = nn.Linear(512 * block.expansion, num_classes)

# init weights
self._initialize_weights()
for m in self.modules():
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)

if zero_init_residual:
for m in self.modules():
Expand Down Expand Up @@ -270,19 +280,6 @@ def _make_layer(

return nn.Sequential(*layers)

def _initialize_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)


def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
model = VideoResNet(**kwargs)
Expand Down
19 changes: 9 additions & 10 deletions torchvision/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)
self._init_weights()

def _init_weights(self):
nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
Expand Down Expand Up @@ -211,28 +209,29 @@ def __init__(
heads_layers["head"] = nn.Linear(representation_size, num_classes)

self.heads = nn.Sequential(heads_layers)
self._init_weights()

def _init_weights(self):
if isinstance(self.conv_proj, nn.Conv2d):
# Init the patchify stem
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.conv_proj.bias)
else:
if self.conv_proj.bias is not None:
nn.init.zeros_(self.conv_proj.bias)
elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
# Init the last 1x1 conv of the conv stem
nn.init.normal_(
self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
)
nn.init.zeros_(self.conv_proj.conv_last.bias)
if self.conv_proj.conv_last.bias is not None:
nn.init.zeros_(self.conv_proj.conv_last.bias)

if hasattr(self.heads, "pre_logits"):
if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)

nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)
if isinstance(self.heads.head, nn.Linear):
nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)

def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape
Expand Down