Skip to content

Commit

Permalink
fix: atrous_rates for deeplabv3_mobilenet_v3_large (fixes #7956) (#8019)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
nvs-abhilash and NicolasHug committed Oct 11, 2023
1 parent 7e2050f commit 70a8e05
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchvision/models/segmentation/deeplabv3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, List, Optional
from typing import Any, Optional, Sequence

import torch
from torch import nn
Expand Down Expand Up @@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel):


class DeepLabHead(nn.Sequential):
def __init__(self, in_channels: int, num_classes: int) -> None:
def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None:
super().__init__(
ASPP(in_channels, [12, 24, 36]),
ASPP(in_channels, atrous_rates),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
Expand Down Expand Up @@ -83,7 +83,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class ASPP(nn.Module):
def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
super().__init__()
modules = []
modules.append(
Expand Down

0 comments on commit 70a8e05

Please sign in to comment.