Skip to content

Commit

Permalink
feat: parameterize atrous_rates for deeplabv3 in DeepLabHead.
Browse files Browse the repository at this point in the history
  • Loading branch information
nvs-abhilash committed Oct 9, 2023
1 parent 16967f0 commit 17b530d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchvision/models/segmentation/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel):


class DeepLabHead(nn.Sequential):
def __init__(self, in_channels: int, num_classes: int) -> None:
def __init__(self, in_channels: int, num_classes: int, atrous_rates: List[int] = [12, 24, 36]) -> None:
super().__init__(
ASPP(in_channels, [12, 24, 36]),
ASPP(in_channels, atrous_rates),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
Expand Down

0 comments on commit 17b530d

Please sign in to comment.