Skip to content

Commit

Permalink
Add argument for low-res datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
vturrisi committed Jun 15, 2024
1 parent b69b4bd commit 0abeaab
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions solo/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def __init__(self, cfg: omegaconf.DictConfig):
self.features_dim: int = self.backbone.inplanes
# remove fc layer
self.backbone.fc = nn.Identity()
cifar = cfg.data.dataset in ["cifar10", "cifar100"]
if cifar:
low_res = cfg.data.dataset in ["cifar10", "cifar100"] or cfg.method_kwargs.get("low_res", False)
if low_res:
self.backbone.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=2, bias=False
)
Expand Down Expand Up @@ -644,8 +644,8 @@ def __init__(
if self.backbone_name.startswith("resnet"):
# remove fc layer
self.momentum_backbone.fc = nn.Identity()
cifar = cfg.data.dataset in ["cifar10", "cifar100"]
if cifar:
low_res = cfg.data.dataset in ["cifar10", "cifar100"] or cfg.method_kwargs.get("low_res", False)
if low_res:
self.momentum_backbone.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=2, bias=False
)
Expand Down

0 comments on commit 0abeaab

Please sign in to comment.