Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate SAM (segment anything) encoder with Unet #757

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,19 @@ Note: In the official github repo the s0 variant has additional num_conv_branche
</div>
</details>

<details>
<summary style="margin-left: 25px;">SAM</summary>
<div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|-----------|:--------:|:---------:|
| sam-vit_b | sa-1b | 91M |
| sam-vit_l | sa-1b | 308M |
| sam-vit_h | sa-1b | 636M |

</div>
</details>


\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

Expand Down
13 changes: 13 additions & 0 deletions docs/encoders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,16 @@ MobileOne
+-----------------+----------+------------+
| mobileone\_s4 | imagenet | 13.6M |
+-----------------+----------+------------+

SAM
~~~~~~~~~~~~~~~~~~~~~

+-----------------+----------+------------+
| Encoder | Weights | Params, M |
+=================+==========+============+
| sam-vit_b | sa-1b | 91M |
+-----------------+----------+------------+
| sam-vit_l | sa-1b | 308M |
+-----------------+----------+------------+
| sam-vit_h | sa-1b | 636M |
+-----------------+----------+------------+
2 changes: 0 additions & 2 deletions docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,3 @@ DeepLabV3
DeepLabV3+
~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.DeepLabV3Plus


1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ torchvision>=0.5.0
pretrainedmodels==0.7.4
efficientnet-pytorch==0.7.1
timm==0.9.2
segment-anything-py==1.0

tqdm
pillow
5 changes: 4 additions & 1 deletion segmentation_models_pytorch/decoders/pan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(
activation: Optional[Union[str, callable]] = None,
upsampling: int = 4,
aux_params: Optional[dict] = None,
encoder_kwargs: Optional[dict] = None,
encoder_depth: int = 5,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, this is probably out of the scope of this PR. Did you test that PAN works with a depth other than 5?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me try it

Copy link
Author

@Rusteam Rusteam May 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbh, I have changed this by mistake, although this change should not hurt. Do you want me to remove it? It does not work with SAM encoder

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, could you remove this code pls?

):
super().__init__()

Expand All @@ -67,9 +69,10 @@ def __init__(
self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=5,
depth=encoder_depth,
weights=encoder_weights,
output_stride=encoder_output_stride,
**({} if encoder_kwargs is None else encoder_kwargs),
Rusteam marked this conversation as resolved.
Show resolved Hide resolved
)

self.decoder = PANDecoder(
Expand Down
41 changes: 31 additions & 10 deletions segmentation_models_pytorch/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .resnet import resnet_encoders
from .dpn import dpn_encoders
from .sam import sam_vit_encoders, SamVitEncoder
from .vgg import vgg_encoders
from .senet import senet_encoders
from .densenet import densenet_encoders
Expand Down Expand Up @@ -46,6 +47,34 @@
encoders.update(timm_gernet_encoders)
encoders.update(mix_transformer_encoders)
encoders.update(mobileone_encoders)
encoders.update(sam_vit_encoders)


def get_pretrained_settings(encoders: dict, encoder_name: str, weights: str) -> dict:
"""Get pretrained settings for encoder from encoders collection.

Args:
encoders: collection of encoders
encoder_name: name of encoder in collection
weights: one of ``None`` (random initialization), ``imagenet`` or other pretrained settings

Returns:
pretrained settings for encoder

Raises:
KeyError: in case of wrong encoder name or pretrained settings name
"""
try:
settings = encoders[encoder_name]["pretrained_settings"][weights]
except KeyError:
raise KeyError(
"Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
weights,
encoder_name,
list(encoders[encoder_name]["pretrained_settings"].keys()),
)
)
return settings


def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
Expand All @@ -69,19 +98,11 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **

params = encoders[name]["params"]
params.update(depth=depth)
params.update(kwargs)
encoder = Encoder(**params)

if weights is not None:
try:
settings = encoders[name]["pretrained_settings"][weights]
except KeyError:
raise KeyError(
"Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
weights,
name,
list(encoders[name]["pretrained_settings"].keys()),
)
)
settings = get_pretrained_settings(encoders, name, weights)
encoder.load_state_dict(model_zoo.load_url(settings["url"]))

encoder.set_in_channels(in_channels, pretrained=weights is not None)
Expand Down
5 changes: 0 additions & 5 deletions segmentation_models_pytorch/encoders/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import torch
import torch.nn as nn
from typing import List
from collections import OrderedDict

from . import _utils as utils


Expand Down
106 changes: 106 additions & 0 deletions segmentation_models_pytorch/encoders/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import math
import warnings
from typing import Mapping, Any

import torch
from segment_anything.modeling import ImageEncoderViT

from segmentation_models_pytorch.encoders._base import EncoderMixin


class SamVitEncoder(EncoderMixin, ImageEncoderViT):
def __init__(self, **kwargs):
self._vit_depth = kwargs.pop("vit_depth")
self._encoder_depth = kwargs.get("depth", 5)
kwargs.update({"depth": self._vit_depth})
super().__init__(**kwargs)
self._out_chans = kwargs.get("out_chans", 256)
self._patch_size = kwargs.get("patch_size", 16)
self._validate()

@property
def output_stride(self):
return 32

def _get_scale_factor(self) -> float:
"""Input image will be downscale by this factor"""
return int(math.log(self._patch_size, 2))

def _validate(self):
# check vit depth
if self._vit_depth not in [12, 24, 32]:
raise ValueError(f"vit_depth must be one of [12, 24, 32], got {self._vit_depth}")
# check output
scale_factor = self._get_scale_factor()
if scale_factor != self._encoder_depth:
raise ValueError(
f"With patch_size={self._patch_size} and depth={self._encoder_depth}, "
"spatial dimensions of model output will not match input spatial dimensions. "
"It is recommended to set encoder depth=4 with default vit patch_size=16."
)

@property
def out_channels(self):
# Fill up with leading zeros to be used in Unet
scale_factor = self._get_scale_factor()
return [0] * scale_factor + [self._out_chans]

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
# Return a list of tensors to match other encoders
return [x, super().forward(x)]
Rusteam marked this conversation as resolved.
Show resolved Hide resolved

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> None:
# Exclude mask_decoder and prompt encoder weights
# and remove 'image_encoder.' prefix
state_dict = {
k.replace("image_encoder.", ""): v
for k, v in state_dict.items()
if not k.startswith("mask_decoder") and not k.startswith("prompt_encoder")
}
missing, unused = super().load_state_dict(state_dict, strict=False)
if len(missing) + len(unused) > 0:
n_loaded = len(state_dict) - len(missing) - len(unused)
warnings.warn(
f"Only {n_loaded} out of pretrained {len(state_dict)} SAM image encoder modules are loaded. "
f"Missing modules: {missing}. Unused modules: {unused}."
)


sam_vit_encoders = {
"sam-vit_h": {
"encoder": SamVitEncoder,
"pretrained_settings": {
"sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"},
},
"params": dict(
embed_dim=1280,
vit_depth=32,
num_heads=16,
global_attn_indexes=[7, 15, 23, 31],
),
},
"sam-vit_l": {
"encoder": SamVitEncoder,
"pretrained_settings": {
"sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"},
},
"params": dict(
embed_dim=1024,
vit_depth=24,
num_heads=16,
global_attn_indexes=[5, 11, 17, 23],
),
},
"sam-vit_b": {
"encoder": SamVitEncoder,
"pretrained_settings": {
"sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"},
},
"params": dict(
embed_dim=768,
vit_depth=12,
num_heads=12,
global_attn_indexes=[2, 5, 8, 11],
),
},
}
Empty file added tests/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def get_encoders():
"resnext101_32x16d",
"resnext101_32x32d",
"resnext101_32x48d",
"sam-vit_h",
"sam-vit_l",
"sam-vit_b",
]
encoders = smp.encoders.get_encoder_names()
encoders = [e for e in encoders if e not in exclude_encoders]
Expand Down
74 changes: 74 additions & 0 deletions tests/test_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
import torch

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_encoder
from tests.test_models import get_sample, _test_forward, _test_forward_backward


@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"])
@pytest.mark.parametrize("img_size", [64, 128])
@pytest.mark.parametrize("patch_size,depth", [(8, 3), (16, 4)])
@pytest.mark.parametrize("vit_depth", [12, 24])
def test_sam_encoder(encoder_name, img_size, patch_size, depth, vit_depth):
encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth, vit_depth=vit_depth)
assert encoder.output_stride == 32

sample = torch.ones(1, 3, img_size, img_size)
with torch.no_grad():
out = encoder(sample)

expected_patches = img_size // patch_size
assert out[-1].size() == torch.Size([1, 256, expected_patches, expected_patches])


def test_sam_encoder_validation_error():
with pytest.raises(ValueError):
get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=5, vit_depth=12)
get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=None)
get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=6)


@pytest.mark.skip(reason="Decoder has been removed, keeping this for future integration")
@pytest.mark.parametrize("decoder_multiclass_output", [True, False])
@pytest.mark.parametrize("n_classes", [1, 3])
def test_sam(decoder_multiclass_output, n_classes):
model = smp.SAM(
"sam-vit_b",
encoder_weights=None,
weights=None,
image_size=64,
decoder_multimask_output=decoder_multiclass_output,
classes=n_classes,
)
sample = get_sample(smp.SAM)
model.eval()

_test_forward(model, sample, test_shape=True)
_test_forward_backward(model, sample, test_shape=True)


@pytest.mark.parametrize("model_class", [smp.Unet])
@pytest.mark.parametrize("decoder_channels,encoder_depth", [([64, 32, 16, 8], 4), ([64, 32, 16, 8], 4)])
def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth):
img_size = 1024
model = model_class(
"sam-vit_b",
encoder_weights=None,
encoder_depth=encoder_depth,
decoder_channels=decoder_channels,
)
smp = torch.ones(1, 3, img_size, img_size)
_test_forward_backward(model, smp, test_shape=True)


@pytest.mark.skip(reason="Run this test manually as it needs to download weights")
def test_sam_weights():
smp.create_model("sam", encoder_name="sam-vit_b", encoder_weights=None, weights="sa-1b")


@pytest.mark.skip(reason="Run this test manually as it needs to download weights")
def test_sam_encoder_weights():
smp.create_model(
"unet", encoder_name="sam-vit_b", encoder_depth=4, encoder_weights="sa-1b", decoder_channels=[64, 32, 16, 8]
)