From 4beb5711ae00f1ccfe78e8f5966e3e82e0ee71af Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Tue, 2 May 2023 16:44:28 +0400 Subject: [PATCH 01/22] add sam encoder and decoder --- .pre-commit-config.yaml | 4 +- segmentation_models_pytorch/__init__.py | 1 + .../decoders/sam/__init__.py | 1 + .../decoders/sam/model.py | 159 ++++++++++++++++++ tests/test_models.py | 14 +- 5 files changed, 176 insertions(+), 3 deletions(-) create mode 100644 segmentation_models_pytorch/decoders/sam/__init__.py create mode 100644 segmentation_models_pytorch/decoders/sam/model.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b4749e04..b545f9b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,8 +4,8 @@ repos: hooks: - id: black args: [ --config=pyproject.toml ] - - repo: https://gitlab.com/pycqa/flake8 - rev: 4.0.1 + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 hooks: - id: flake8 args: [ --config=.flake8 ] diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index 1ac9e1fb..788d78f7 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -12,6 +12,7 @@ from .decoders.pspnet import PSPNet from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus from .decoders.pan import PAN +from .decoders.sam import SAM from .__version__ import __version__ diff --git a/segmentation_models_pytorch/decoders/sam/__init__.py b/segmentation_models_pytorch/decoders/sam/__init__.py new file mode 100644 index 00000000..dfc17c68 --- /dev/null +++ b/segmentation_models_pytorch/decoders/sam/__init__.py @@ -0,0 +1 @@ +from .model import SAM diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py new file mode 100644 index 00000000..de2d0c56 --- /dev/null +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -0,0 +1,159 @@ +from typing import Optional, Union, List, Tuple + +import torch +from torch.nn import functional as F + +from segmentation_models_pytorch.base import ( + SegmentationModel, + SegmentationHead, + ClassificationHead, +) + + +class SAM(SegmentationModel): + """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* + and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial + resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* + for fusing decoder blocks with skip connections. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution + encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 + encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and + other pretrained weights (see table with available weights for each encoder_name) + decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. + Length of the list should be the same as **encoder_depth** + decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers + is used. 
If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. + Available options are **True, False, "inplace"** + decoder_attention_type: Attention module used in decoder of the model. Available options are + **None** and **scse** (https://arxiv.org/abs/1808.08127). + in_channels: A number of input channels for the model, default is 3 (RGB images) + classes: A number of classes for output mask (or you can think as a number of channels of output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. + Default is **None** + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build + on top of encoder if **aux_params** is not **None** (default). Supported params: + - classes (int): A number of classes + - pooling (str): One of "max", "avg". Default is "avg" + - dropout (float): Dropout factor in [0, 1) + - activation (str): An activation function to apply "sigmoid"/"softmax" + (could be **None** to return logits) + + Returns: + ``torch.nn.Module``: Unet + + .. _Unet: + https://arxiv.org/abs/1505.04597 + + """ + + def __init__( + self, + encoder_name: str = "vit_h", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_use_batchnorm: bool = True, + decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_attention_type: Optional[str] = None, + in_channels: int = 3, + image_size: int = 1024, + vit_patch_size: int = 16, + classes: int = 1, + activation: Optional[Union[str, callable]] = None, + aux_params: Optional[dict] = None, + ): + super().__init__() + from segment_anything import sam_model_registry + + sam = sam_model_registry[encoder_name]( + checkpoint=encoder_weights, image_size=image_size, vit_patch_size=vit_patch_size + ) + + self.pixel_mean = sam.pixel_mean + self.pixel_std = sam.pixel_std + + self.encoder = sam.image_encoder + self.encoder.output_stride = 32 # TODO fix this + self.prompt_encoder = sam.prompt_encoder + + self.decoder = sam.mask_decoder + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels[-1], + out_channels=classes, + activation=activation, + kernel_size=3, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) + else: + self.classification_head = None + + self.name = "sam-{}".format(encoder_name) + self.initialize() + + def preprocess(self, x): + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.encoder.img_size - h + padw = self.encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. 
+ """ + masks = F.interpolate( + masks, + (self.encoder.img_size, self.encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def forward(self, x): + img_size = x.shape[-2:] + x = torch.stack([self.preprocess(img) for img in x]) + features = self.encoder(x) + sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None) + low_res_masks, iou_preidctions = self.decoder( + image_embeddings=features, + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + ) + masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size) + return masks diff --git a/tests/test_models.py b/tests/test_models.py index c2e6d941..5dd283e6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,7 +26,7 @@ def get_encoders(): def get_sample(model_class): - if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet]: + if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet, smp.SAM]: sample = torch.ones([1, 3, 64, 64]) elif model_class == smp.PAN: sample = torch.ones([2, 3, 256, 256]) @@ -136,5 +136,17 @@ def test_dilation(encoder_name): assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation +@pytest.mark.parametrize("encoder_name", ["vit_b", "vit_l"]) +@pytest.mark.parametrize("image_size", [64, 128]) +def test_sam(encoder_name, image_size): + model_class = smp.SAM + model = model_class(encoder_name, encoder_weights=None, image_size=image_size) + sample = get_sample(model_class) + model.eval() + + _test_forward(model, sample, test_shape=True) + _test_forward_backward(model, sample, test_shape=True) + + if __name__ == "__main__": pytest.main([__file__]) From 3dc3235d8b91564d49e2cf27cece47bb2a99ee9e Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Wed, 3 May 2023 12:22:19 +0400 Subject: [PATCH 02/22] refactor sam vit encoder to common format --- .../decoders/sam/model.py | 21 +++++--- .../encoders/__init__.py | 10 +++- segmentation_models_pytorch/encoders/_base.py | 5 -- segmentation_models_pytorch/encoders/sam.py | 50 +++++++++++++++++++ tests/__init__.py | 0 tests/test_models.py | 12 ----- tests/test_sam.py | 34 +++++++++++++ 7 files changed, 107 insertions(+), 25 deletions(-) create mode 100644 segmentation_models_pytorch/encoders/sam.py create mode 100644 tests/__init__.py create mode 100644 tests/test_sam.py diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index de2d0c56..35f0824e 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -8,6 +8,7 @@ SegmentationHead, ClassificationHead, ) +from segmentation_models_pytorch.encoders import get_encoder class SAM(SegmentationModel): @@ -56,9 +57,9 @@ class SAM(SegmentationModel): def __init__( self, - encoder_name: str = "vit_h", - encoder_depth: int = 5, - encoder_weights: Optional[str] = "imagenet", + encoder_name: str = "sam-vit_h", + encoder_depth: int = None, + encoder_weights: Optional[str] = "sam-vit_h", decoder_use_batchnorm: bool = True, decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, @@ -72,15 +73,21 @@ def __init__( super().__init__() 
from segment_anything import sam_model_registry - sam = sam_model_registry[encoder_name]( + sam = sam_model_registry[encoder_name[4:]]( checkpoint=encoder_weights, image_size=image_size, vit_patch_size=vit_patch_size ) self.pixel_mean = sam.pixel_mean self.pixel_std = sam.pixel_std - self.encoder = sam.image_encoder - self.encoder.output_stride = 32 # TODO fix this + self.encoder = get_encoder( + encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + img_size=image_size, + patch_size=vit_patch_size, + ) self.prompt_encoder = sam.prompt_encoder self.decoder = sam.mask_decoder @@ -97,7 +104,7 @@ def __init__( else: self.classification_head = None - self.name = "sam-{}".format(encoder_name) + self.name = encoder_name self.initialize() def preprocess(self, x): diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 2a3ff5c0..a2a23983 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -4,6 +4,7 @@ from .resnet import resnet_encoders from .dpn import dpn_encoders +from .sam import sam_vit_encoders, SamVitEncoder from .vgg import vgg_encoders from .senet import senet_encoders from .densenet import densenet_encoders @@ -46,6 +47,7 @@ encoders.update(timm_gernet_encoders) encoders.update(mix_transformer_encoders) encoders.update(mobileone_encoders) +encoders.update(sam_vit_encoders) def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): @@ -68,7 +70,13 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) params = encoders[name]["params"] - params.update(depth=depth) + if name.startswith("sam-"): + params.update(**kwargs) + params.update(dict(name=name[4:])) + if depth is not None: + params.update(depth=depth) + else: + params.update(depth=depth) encoder = Encoder(**params) if weights is not None: diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py index aab838f1..fee8d177 100644 --- a/segmentation_models_pytorch/encoders/_base.py +++ b/segmentation_models_pytorch/encoders/_base.py @@ -1,8 +1,3 @@ -import torch -import torch.nn as nn -from typing import List -from collections import OrderedDict - from . 
import _utils as utils diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py new file mode 100644 index 00000000..8d5438a6 --- /dev/null +++ b/segmentation_models_pytorch/encoders/sam.py @@ -0,0 +1,50 @@ +from segment_anything.modeling import ImageEncoderViT + +from segmentation_models_pytorch.encoders._base import EncoderMixin + + +class SamVitEncoder(EncoderMixin, ImageEncoderViT): + def __init__(self, name: str, **kwargs): + super().__init__(**kwargs) + self._name = name + self._depth = kwargs["depth"] + + +sam_vit_encoders = { + "sam-vit_h": { + "encoder": SamVitEncoder, + "pretrained_settings": { + "sam-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"}, + }, + "params": dict( + embed_dim=1280, + depth=32, + num_heads=16, + global_attn_indexes=[7, 15, 23, 31], + ), + }, + "sam-vit_l": { + "encoder": SamVitEncoder, + "pretrained_settings": { + "sam-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"}, + }, + "params": dict( + embed_dim=1024, + depth=24, + num_heads=16, + global_attn_indexes=[5, 11, 17, 23], + ), + }, + "sam-vit_b": { + "encoder": SamVitEncoder, + "pretrained_settings": { + "sam-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"}, + }, + "params": dict( + embed_dim=768, + depth=12, + num_heads=12, + global_attn_indexes=[2, 5, 8, 11], + ), + }, +} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_models.py b/tests/test_models.py index 5dd283e6..cfdf3a11 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -136,17 +136,5 @@ def test_dilation(encoder_name): assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation -@pytest.mark.parametrize("encoder_name", ["vit_b", "vit_l"]) -@pytest.mark.parametrize("image_size", [64, 128]) -def test_sam(encoder_name, image_size): - model_class = smp.SAM - model = model_class(encoder_name, encoder_weights=None, image_size=image_size) - sample = get_sample(model_class) - model.eval() - - _test_forward(model, sample, test_shape=True) - _test_forward_backward(model, sample, test_shape=True) - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_sam.py b/tests/test_sam.py new file mode 100644 index 00000000..a48d52fe --- /dev/null +++ b/tests/test_sam.py @@ -0,0 +1,34 @@ +import pytest +import torch + +import segmentation_models_pytorch as smp +from segmentation_models_pytorch.encoders import get_encoder +from tests.test_models import get_sample, _test_forward, _test_forward_backward + + +@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"]) +@pytest.mark.parametrize("img_size", [64, 128]) +@pytest.mark.parametrize("patch_size", [8, 16]) +def test_sam_encoder(encoder_name, img_size, patch_size): + encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size) + assert encoder._name == encoder_name[4:] + assert encoder.output_stride == 32 + + sample = torch.ones(1, 3, img_size, img_size) + with torch.no_grad(): + out = encoder(sample) + + expected_patches = img_size // patch_size + assert out.size() == torch.Size([1, 256, expected_patches, expected_patches]) + + +@pytest.mark.parametrize("encoder_name", ["sam-vit_b"]) +@pytest.mark.parametrize("image_size", [64]) +def test_sam(encoder_name, image_size): + model_class = smp.SAM + model = model_class(encoder_name, encoder_weights=None, image_size=image_size) + sample = get_sample(model_class) 
+ model.eval() + + _test_forward(model, sample, test_shape=True) + _test_forward_backward(model, sample, test_shape=True) From 85565ce3d9f612303403536a531ccb85c699eb13 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Wed, 3 May 2023 12:52:17 +0400 Subject: [PATCH 03/22] refactor sam decoder init --- .../decoders/sam/model.py | 39 +++++++++++++------ tests/test_sam.py | 4 +- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index 35f0824e..8e237741 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -1,6 +1,7 @@ from typing import Optional, Union, List, Tuple import torch +from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder from torch.nn import functional as F from segmentation_models_pytorch.base import ( @@ -61,7 +62,7 @@ def __init__( encoder_depth: int = None, encoder_weights: Optional[str] = "sam-vit_h", decoder_use_batchnorm: bool = True, - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_channels: List[int] = 256, decoder_attention_type: Optional[str] = None, in_channels: int = 3, image_size: int = 1024, @@ -71,14 +72,9 @@ def __init__( aux_params: Optional[dict] = None, ): super().__init__() - from segment_anything import sam_model_registry - sam = sam_model_registry[encoder_name[4:]]( - checkpoint=encoder_weights, image_size=image_size, vit_patch_size=vit_patch_size - ) - - self.pixel_mean = sam.pixel_mean - self.pixel_std = sam.pixel_std + self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + self.pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) self.encoder = get_encoder( encoder_name, @@ -87,13 +83,32 @@ def __init__( weights=encoder_weights, img_size=image_size, patch_size=vit_patch_size, + out_chans=decoder_channels, ) - self.prompt_encoder = sam.prompt_encoder - self.decoder = sam.mask_decoder + image_embedding_size = image_size // vit_patch_size + self.prompt_encoder = PromptEncoder( + embed_dim=decoder_channels, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ) + + self.decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=decoder_channels, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=decoder_channels, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) self.segmentation_head = SegmentationHead( - in_channels=decoder_channels[-1], + in_channels=decoder_channels, out_channels=classes, activation=activation, kernel_size=3, @@ -155,7 +170,7 @@ def forward(self, x): x = torch.stack([self.preprocess(img) for img in x]) features = self.encoder(x) sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None) - low_res_masks, iou_preidctions = self.decoder( + low_res_masks, iou_predictions = self.decoder( image_embeddings=features, image_pe=self.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, diff --git a/tests/test_sam.py b/tests/test_sam.py index a48d52fe..992bfa6d 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -22,8 +22,8 @@ def test_sam_encoder(encoder_name, img_size, patch_size): assert out.size() == torch.Size([1, 256, expected_patches, expected_patches]) -@pytest.mark.parametrize("encoder_name", ["sam-vit_b"]) -@pytest.mark.parametrize("image_size", [64]) 
+@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"]) +@pytest.mark.parametrize("image_size", [64, 128]) def test_sam(encoder_name, image_size): model_class = smp.SAM model = model_class(encoder_name, encoder_weights=None, image_size=image_size) From 11436681ab15d3168a38d06d2c34316077d4972f Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Wed, 3 May 2023 15:03:08 +0400 Subject: [PATCH 04/22] add segmentation head to sam --- .../decoders/sam/model.py | 44 ++++++++----------- tests/test_sam.py | 22 ++++++---- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index 8e237741..f4a9fe81 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -7,33 +7,25 @@ from segmentation_models_pytorch.base import ( SegmentationModel, SegmentationHead, - ClassificationHead, ) from segmentation_models_pytorch.encoders import get_encoder class SAM(SegmentationModel): - """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* - and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial - resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* - for fusing decoder blocks with skip connections. + """SAM_ (Segment Anything Model) is a visual transformer based encoder-decoder segmentation + model that can be used to produce high quality segmentation masks from images and prompts. + Consists of *image encoder*, *prompt encoder* and *mask decoder*. *Segmentation head* is + added after the *mask decoder* to define the final number of classes for the output mask. Args: encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) to extract features of different spatial resolution - encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features + encoder_depth: A number of stages used in encoder in range [6, 24]. Each stage generate features two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). Default is 5 - encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and - other pretrained weights (see table with available weights for each encoder_name) - decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. - Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** - decoder_attention_type: Attention module used in decoder of the model. Available options are - **None** and **scse** (https://arxiv.org/abs/1808.08127). + encoder_weights: One of **None** (random initialization), **"sa-1b"** (pre-training on SA-1B dataset). + decoder_channels: How many output channels image encoder will have. Default is 256. 
in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. @@ -49,10 +41,10 @@ class SAM(SegmentationModel): (could be **None** to return logits) Returns: - ``torch.nn.Module``: Unet + ``torch.nn.Module``: SAM - .. _Unet: - https://arxiv.org/abs/1505.04597 + .. _SAM: + https://github.com/facebookresearch/segment-anything """ @@ -61,9 +53,8 @@ def __init__( encoder_name: str = "sam-vit_h", encoder_depth: int = None, encoder_weights: Optional[str] = "sam-vit_h", - decoder_use_batchnorm: bool = True, decoder_channels: List[int] = 256, - decoder_attention_type: Optional[str] = None, + decoder_multimask_output: bool = True, in_channels: int = 3, image_size: int = 1024, vit_patch_size: int = 16, @@ -106,18 +97,18 @@ def __init__( iou_head_depth=3, iou_head_hidden_dim=256, ) + self._decoder_multiclass_output = decoder_multimask_output self.segmentation_head = SegmentationHead( - in_channels=decoder_channels, + in_channels=3 if decoder_multimask_output else 1, out_channels=classes, activation=activation, kernel_size=3, ) if aux_params is not None: - self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params) - else: - self.classification_head = None + raise NotImplementedError("Auxiliary output is not supported yet") + self.classification_head = None self.name = encoder_name self.initialize() @@ -175,7 +166,8 @@ def forward(self, x): image_pe=self.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, - multimask_output=False, + multimask_output=self._decoder_multiclass_output, ) masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size) - return masks + output = self.segmentation_head(masks) + return output diff --git a/tests/test_sam.py b/tests/test_sam.py index 992bfa6d..1d288296 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -9,8 +9,9 @@ @pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"]) @pytest.mark.parametrize("img_size", [64, 128]) @pytest.mark.parametrize("patch_size", [8, 16]) -def test_sam_encoder(encoder_name, img_size, patch_size): - encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size) +@pytest.mark.parametrize("depth", [6, 24, None]) +def test_sam_encoder(encoder_name, img_size, patch_size, depth): + encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth) assert encoder._name == encoder_name[4:] assert encoder.output_stride == 32 @@ -22,12 +23,17 @@ def test_sam_encoder(encoder_name, img_size, patch_size): assert out.size() == torch.Size([1, 256, expected_patches, expected_patches]) -@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"]) -@pytest.mark.parametrize("image_size", [64, 128]) -def test_sam(encoder_name, image_size): - model_class = smp.SAM - model = model_class(encoder_name, encoder_weights=None, image_size=image_size) - sample = get_sample(model_class) +@pytest.mark.parametrize("decoder_multiclass_output", [True, False]) +@pytest.mark.parametrize("n_classes", [1, 3]) +def test_sam(decoder_multiclass_output, n_classes): + model = smp.SAM( + "sam-vit_b", + encoder_weights=None, + image_size=64, + decoder_multimask_output=decoder_multiclass_output, + classes=n_classes, + ) + sample = get_sample(smp.SAM) model.eval() _test_forward(model, 
sample, test_shape=True) From 48033cb969f780d4565e14adf5e5127a372c9f86 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Wed, 3 May 2023 15:17:43 +0400 Subject: [PATCH 05/22] add sam to encoder and model docs --- docs/encoders.rst | 13 +++++++++++++ docs/models.rst | 3 +++ 2 files changed, 16 insertions(+) diff --git a/docs/encoders.rst b/docs/encoders.rst index 32587c8d..9c65e8ec 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -361,3 +361,16 @@ MobileOne +-----------------+----------+------------+ | mobileone\_s4 | imagenet | 13.6M | +-----------------+----------+------------+ + +SAM +~~~~~~~~~~~~~~~~~~~~~ + ++-----------------+----------+------------+ +| Encoder | Weights | Params, M | ++=================+==========+============+ +| sam-vit_b | sa-1b | 91M | ++-----------------+----------+------------+ +| sam-vit_l | sa-1b | 308M | ++-----------------+----------+------------+ +| sam-vit_h | sa-1b | 636M | ++-----------------+----------+------------+ diff --git a/docs/models.rst b/docs/models.rst index 47de61ee..06c1f3ce 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -37,4 +37,7 @@ DeepLabV3+ ~~~~~~~~~~ .. autoclass:: segmentation_models_pytorch.DeepLabV3Plus +SAM +~~~~~~~~~~ +.. autoclass:: segmentation_models_pytorch.SAM From 1f1eacaa514529c2b0239ae26f7ad44de878469a Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Wed, 3 May 2023 15:26:53 +0400 Subject: [PATCH 06/22] remove sam encoders from test_models --- tests/test_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index cfdf3a11..c5dada3f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,6 +14,9 @@ def get_encoders(): "resnext101_32x16d", "resnext101_32x32d", "resnext101_32x48d", + "sam-vit_h", + "sam-vit_l", + "sam-vit_b", ] encoders = smp.encoders.get_encoder_names() encoders = [e for e in encoders if e not in exclude_encoders] From f37c9b3b0094e8133f7f8cf3426767b6893ebf25 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Thu, 4 May 2023 14:07:40 +0400 Subject: [PATCH 07/22] wip weights --- segmentation_models_pytorch/__init__.py | 1 + .../decoders/sam/model.py | 29 +++++++++++++- .../encoders/__init__.py | 38 ++++++++++++++----- segmentation_models_pytorch/encoders/sam.py | 6 +-- tests/test_sam.py | 5 +++ 5 files changed, 64 insertions(+), 15 deletions(-) diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index 788d78f7..ccf56715 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -43,6 +43,7 @@ def create_model( DeepLabV3, DeepLabV3Plus, PAN, + SAM, ] archs_dict = {a.__name__.lower(): a for a in archs} try: diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index f4a9fe81..953247c4 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -1,14 +1,22 @@ +import logging from typing import Optional, Union, List, Tuple import torch from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder from torch.nn import functional as F +from torch.utils import model_zoo from segmentation_models_pytorch.base import ( SegmentationModel, SegmentationHead, ) -from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.encoders import get_encoder, sam_vit_encoders, get_pretrained_settings + +logger = logging.getLogger("sam") +logger.setLevel(logging.WARNING) +stream = 
logging.StreamHandler() +logger.addHandler(stream) +logger.propagate = False class SAM(SegmentationModel): @@ -52,13 +60,14 @@ def __init__( self, encoder_name: str = "sam-vit_h", encoder_depth: int = None, - encoder_weights: Optional[str] = "sam-vit_h", + encoder_weights: Optional[str] = None, decoder_channels: List[int] = 256, decoder_multimask_output: bool = True, in_channels: int = 3, image_size: int = 1024, vit_patch_size: int = 16, classes: int = 1, + weights: Optional[str] = "sa-1b", activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, ): @@ -99,6 +108,9 @@ def __init__( ) self._decoder_multiclass_output = decoder_multimask_output + if weights is not None: + self._load_pretrained_weights(encoder_name, weights) + self.segmentation_head = SegmentationHead( in_channels=3 if decoder_multimask_output else 1, out_channels=classes, @@ -113,6 +125,19 @@ def __init__( self.name = encoder_name self.initialize() + def _load_pretrained_weights(self, encoder_name: str, weights: str): + settings = get_pretrained_settings(sam_vit_encoders, encoder_name, weights) + state_dict = model_zoo.load_url(settings["url"]) + state_dict = {k.replace("image_encoder", "encoder"): v for k, v in state_dict.items()} + state_dict = {k.replace("mask_decoder", "decoder"): v for k, v in state_dict.items()} + missing, unused = self.load_state_dict(state_dict, strict=False) + if len(missing) > 0 or len(unused) > 0: + n_loaded = len(state_dict) - len(missing) - len(unused) + logger.warning( + f"Only {n_loaded} out of pretrained {len(state_dict)} SAM modules are loaded. " + f"Missing modules: {missing}. Unused modules: {unused}." + ) + def preprocess(self, x): """Normalize pixel values and pad to a square input.""" # Normalize colors diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index a2a23983..ba378444 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -50,6 +50,33 @@ encoders.update(sam_vit_encoders) +def get_pretrained_settings(encoders: dict, encoder_name: str, weights: str) -> dict: + """Get pretrained settings for encoder from encoders collection. + + Args: + encoders: collection of encoders + encoder_name: name of encoder in collection + weights: one of ``None`` (random initialization), ``imagenet`` or other pretrained settings + + Returns: + pretrained settings for encoder + + Raises: + KeyError: in case of wrong encoder name or pretrained settings name + """ + try: + settings = encoders[encoder_name]["pretrained_settings"][weights] + except KeyError: + raise KeyError( + "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format( + weights, + encoder_name, + list(encoders[encoder_name]["pretrained_settings"].keys()), + ) + ) + return settings + + def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): if name.startswith("tu-"): @@ -80,16 +107,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** encoder = Encoder(**params) if weights is not None: - try: - settings = encoders[name]["pretrained_settings"][weights] - except KeyError: - raise KeyError( - "Wrong pretrained weights `{}` for encoder `{}`. 
Available options are: {}".format( - weights, - name, - list(encoders[name]["pretrained_settings"].keys()), - ) - ) + settings = get_pretrained_settings(encoders, name, weights) encoder.load_state_dict(model_zoo.load_url(settings["url"])) encoder.set_in_channels(in_channels, pretrained=weights is not None) diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py index 8d5438a6..4518111b 100644 --- a/segmentation_models_pytorch/encoders/sam.py +++ b/segmentation_models_pytorch/encoders/sam.py @@ -14,7 +14,7 @@ def __init__(self, name: str, **kwargs): "sam-vit_h": { "encoder": SamVitEncoder, "pretrained_settings": { - "sam-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"}, + "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"}, }, "params": dict( embed_dim=1280, @@ -26,7 +26,7 @@ def __init__(self, name: str, **kwargs): "sam-vit_l": { "encoder": SamVitEncoder, "pretrained_settings": { - "sam-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"}, + "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"}, }, "params": dict( embed_dim=1024, @@ -38,7 +38,7 @@ def __init__(self, name: str, **kwargs): "sam-vit_b": { "encoder": SamVitEncoder, "pretrained_settings": { - "sam-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"}, + "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"}, }, "params": dict( embed_dim=768, diff --git a/tests/test_sam.py b/tests/test_sam.py index 1d288296..74fa5bc7 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -38,3 +38,8 @@ def test_sam(decoder_multiclass_output, n_classes): _test_forward(model, sample, test_shape=True) _test_forward_backward(model, sample, test_shape=True) + + +@pytest.mark.skip(reason="Run this test manually as it needs to download weights") +def test_sam_weights(): + smp.create_model("sam", encoder_name="sam-vit_b", encoder_weights=None, weights="sa-1b") From 6b369275e2c1e22a9f1314082116c324800b1aac Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Thu, 4 May 2023 14:07:40 +0400 Subject: [PATCH 08/22] load pretrained sam state dict for a model --- segmentation_models_pytorch/decoders/sam/model.py | 4 ++-- tests/test_sam.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index 953247c4..39afb390 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -73,8 +73,8 @@ def __init__( ): super().__init__() - self.pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) - self.pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + self.register_buffer("pixel_mean", torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), False) self.encoder = get_encoder( encoder_name, diff --git a/tests/test_sam.py b/tests/test_sam.py index 74fa5bc7..10a51791 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -29,6 +29,7 @@ def test_sam(decoder_multiclass_output, n_classes): model = smp.SAM( "sam-vit_b", encoder_weights=None, + weights=None, image_size=64, decoder_multimask_output=decoder_multiclass_output, classes=n_classes, From 64a25165c8d8e8598fcf1dbf3d1cc564ea7830c9 Mon Sep 17 00:00:00 2001 From: Rustem 
Galiullin Date: Thu, 4 May 2023 17:20:40 +0400 Subject: [PATCH 09/22] update readme with sam model and encoders --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index f1e51b37..d6fb4eb9 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ Congratulations! You are done! Now you can train your model with your favorite f - PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)] - DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)] - DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)] + - SAM [[paper](https://ai.facebook.com/research/publications/segment-anything/)] [[docs](https://github.com/facebookresearch/segment-anything)] #### Encoders @@ -394,6 +395,19 @@ Note: In the official github repo the s0 variant has additional num_conv_branche +
+SAM
+
+| Encoder   | Weights | Params, M |
+|-----------|:-------:|:---------:|
+| sam-vit_b |  sa-1b  |    91M    |
+| sam-vit_l |  sa-1b  |   308M    |
+| sam-vit_h |  sa-1b  |   636M    |
+
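For orientation, here is a minimal usage sketch of the API this patch series introduces, assembled from the tests it adds (`tests/test_sam.py`). The keyword names (`weights`, `decoder_multimask_output`, `encoder_kwargs`) are taken from those tests; note that a later commit in the series removes the standalone `SAM` decoder as unstable, so only the encoder path survives at the end.

```python
import torch
import segmentation_models_pytorch as smp

# Standalone SAM model as added in this series (removed again in a later commit).
# image_size should be divisible by the ViT patch size (16 by default);
# weights=None / encoder_weights=None avoid downloading the SA-1B checkpoint.
sam = smp.SAM("sam-vit_b", encoder_weights=None, weights=None, image_size=64, classes=1)
masks = sam(torch.ones(1, 3, 64, 64))  # -> (1, 1, 64, 64)

# SAM ViT used purely as an encoder for Unet (the form kept at the end of the series).
# With patch_size=16 the encoder downscales by 2**4, so four decoder channels are expected;
# encoder_weights="sa-1b" would instead load the pretrained SAM image encoder.
unet = smp.Unet(
    "sam-vit_b",
    encoder_weights=None,
    encoder_depth=3,
    encoder_kwargs=dict(img_size=64, out_chans=64, patch_size=16),
    decoder_channels=[64, 32, 16, 8],
)
out = unet(torch.ones(1, 3, 64, 64))  # -> (1, 1, 64, 64)
```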
+ \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)). From 4d1144e88beec27a5e4aee5c22dfddfe7afb33fc Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Mon, 8 May 2023 16:41:11 +0400 Subject: [PATCH 10/22] use iou scaling to avoid errors with torch ddp --- segmentation_models_pytorch/decoders/sam/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index 39afb390..cc6eadf1 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -194,5 +194,7 @@ def forward(self, x): multimask_output=self._decoder_multiclass_output, ) masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size) + # use scaling below in order to make it work with torch DDP + masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1) output = self.segmentation_head(masks) return output From c1a93198bf7df380f3d191e6c6ce01b73ebf9c14 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Tue, 9 May 2023 10:08:28 +0400 Subject: [PATCH 11/22] set unused sam modules to require grad False --- segmentation_models_pytorch/decoders/sam/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index cc6eadf1..fec4f313 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -198,3 +198,8 @@ def forward(self, x): masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1) output = self.segmentation_head(masks) return output + + def train(self, mode: bool = True): + super(SAM, self).train(mode) + self.prompt_encoder.point_embeddings.requires_grad = False + self.prompt_encoder.mask_downscaling.requires_grad = False From 2ed775d5df7e5142d4ee0c2afd3cbb1c888289f4 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Tue, 9 May 2023 10:22:57 +0400 Subject: [PATCH 12/22] set unused sam modules to None --- segmentation_models_pytorch/decoders/sam/model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index fec4f313..4091c282 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -93,6 +93,9 @@ def __init__( input_image_size=(image_size, image_size), mask_in_chans=16, ) + self.prompt_encoder.point_embeddings = None + self.prompt_encoder.mask_downscaling = None + self.not_a_point_embed = None self.decoder = MaskDecoder( num_multimask_outputs=3, @@ -198,8 +201,3 @@ def forward(self, x): masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1) output = self.segmentation_head(masks) return output - - def train(self, mode: bool = True): - super(SAM, self).train(mode) - self.prompt_encoder.point_embeddings.requires_grad = False - self.prompt_encoder.mask_downscaling.requires_grad = False From 9c93eb435f69b6780da3a7258a08774f4187c00c Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Tue, 9 May 2023 11:38:11 +0400 Subject: [PATCH 13/22] remove prompt encoder from sam --- .../decoders/sam/model.py | 32 ++++++++++++------- segmentation_models_pytorch/encoders/sam.py | 5 +++ 2 files changed, 26 insertions(+), 11 deletions(-) diff --git 
a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index 4091c282..c7d19a41 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -3,6 +3,8 @@ import torch from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder +from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom +from torch import nn from torch.nn import functional as F from torch.utils import model_zoo @@ -86,16 +88,12 @@ def __init__( out_chans=decoder_channels, ) + # this params are used instead of prompt_encoder image_embedding_size = image_size // vit_patch_size - self.prompt_encoder = PromptEncoder( - embed_dim=decoder_channels, - image_embedding_size=(image_embedding_size, image_embedding_size), - input_image_size=(image_size, image_size), - mask_in_chans=16, - ) - self.prompt_encoder.point_embeddings = None - self.prompt_encoder.mask_downscaling = None - self.not_a_point_embed = None + self.embed_dim = decoder_channels + self.image_embedding_size = (image_embedding_size, image_embedding_size) + self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2) + self.no_mask_embed = nn.Embedding(1, decoder_channels) self.decoder = MaskDecoder( num_multimask_outputs=3, @@ -188,10 +186,11 @@ def forward(self, x): img_size = x.shape[-2:] x = torch.stack([self.preprocess(img) for img in x]) features = self.encoder(x) - sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None) + # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None) + sparse_embeddings, dense_embeddings = self._get_dummy_promp_encoder_output(x.size(0)) low_res_masks, iou_predictions = self.decoder( image_embeddings=features, - image_pe=self.prompt_encoder.get_dense_pe(), + image_pe=self._get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=self._decoder_multiclass_output, @@ -201,3 +200,14 @@ def forward(self, x): masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1) output = self.segmentation_head(masks) return output + + def _get_dummy_promp_encoder_output(self, bs): + """Use this dummy output as we're training without prompts.""" + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + return sparse_embeddings, dense_embeddings + + def _get_dense_pe(self): + return self.pe_layer(self.image_embedding_size).unsqueeze(0) diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py index 4518111b..af86bfae 100644 --- a/segmentation_models_pytorch/encoders/sam.py +++ b/segmentation_models_pytorch/encoders/sam.py @@ -8,6 +8,11 @@ def __init__(self, name: str, **kwargs): super().__init__(**kwargs) self._name = name self._depth = kwargs["depth"] + self._out_chans = kwargs.get("out_chans", 256) + + @property + def out_channels(self): + return [-1, self._out_chans] sam_vit_encoders = { From 500779e6340055f2400159d4cecad16dda188d2f Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Thu, 18 May 2023 09:03:29 +0400 Subject: [PATCH 14/22] add segment-anything to reqs --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 9dc118f2..0807c07f 100644 --- a/requirements.txt 
+++ b/requirements.txt @@ -2,6 +2,7 @@ torchvision>=0.5.0 pretrainedmodels==0.7.4 efficientnet-pytorch==0.7.1 timm==0.6.13 +git+https://github.com/facebookresearch/segment-anything.git tqdm pillow From b301d306c61f28da15bd8935e0f163ec0e008b8a Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Fri, 19 May 2023 09:28:31 +0400 Subject: [PATCH 15/22] integrate sam encoder to Unet model --- .../decoders/pan/model.py | 5 +++- .../decoders/sam/model.py | 3 +-- .../decoders/unet/model.py | 12 ++++++++- segmentation_models_pytorch/encoders/sam.py | 27 ++++++++++++++++++- tests/test_sam.py | 17 +++++++++++- 5 files changed, 58 insertions(+), 6 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index 838d3e85..e618977a 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -58,6 +58,8 @@ def __init__( activation: Optional[Union[str, callable]] = None, upsampling: int = 4, aux_params: Optional[dict] = None, + encoder_kwargs: Optional[dict] = None, + encoder_depth: int = 5, ): super().__init__() @@ -67,9 +69,10 @@ def __init__( self.encoder = get_encoder( encoder_name, in_channels=in_channels, - depth=5, + depth=encoder_depth, weights=encoder_weights, output_stride=encoder_output_stride, + **({} if encoder_kwargs is None else encoder_kwargs), ) self.decoder = PANDecoder( diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index c7d19a41..ba9e35f5 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -185,8 +185,7 @@ def postprocess_masks( def forward(self, x): img_size = x.shape[-2:] x = torch.stack([self.preprocess(img) for img in x]) - features = self.encoder(x) - # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None) + *_, features = self.encoder(x) sparse_embeddings, dense_embeddings = self._get_dummy_promp_encoder_output(x.size(0)) low_res_masks, iou_predictions = self.decoder( image_embeddings=features, diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 5baf043f..38732ebc 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -65,20 +65,30 @@ def __init__( classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, + encoder_kwargs: Optional[dict] = None, ): super().__init__() + # if sam encoder, make sure to make num_hidden_skips is set + if encoder_name.startswith("sam-"): + encoder_kwargs = encoder_kwargs if encoder_kwargs is not None else {} + encoder_kwargs.update({"num_hidden_skips": len(decoder_channels)}) + n_decoder_blocks = len(decoder_channels) + else: + n_decoder_blocks = encoder_depth + self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, + **encoder_kwargs if encoder_kwargs is not None else {}, ) self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, - n_blocks=encoder_depth, + n_blocks=n_decoder_blocks, use_batchnorm=decoder_use_batchnorm, center=True if encoder_name.startswith("vgg") else False, attention_type=decoder_attention_type, diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py index af86bfae..3b309ef7 100644 --- 
a/segmentation_models_pytorch/encoders/sam.py +++ b/segmentation_models_pytorch/encoders/sam.py @@ -1,3 +1,6 @@ +import math + +import torch from segment_anything.modeling import ImageEncoderViT from segmentation_models_pytorch.encoders._base import EncoderMixin @@ -5,14 +8,36 @@ class SamVitEncoder(EncoderMixin, ImageEncoderViT): def __init__(self, name: str, **kwargs): + patch_size = kwargs.get("patch_size", 16) + n_skips = kwargs.pop("num_hidden_skips", int(self._get_scale_factor(patch_size))) super().__init__(**kwargs) self._name = name self._depth = kwargs["depth"] self._out_chans = kwargs.get("out_chans", 256) + self._num_skips = n_skips + self._validate_output(patch_size) + + @staticmethod + def _get_scale_factor(patch_size: int) -> float: + """Input image will be downscale by this factor""" + return math.log(patch_size, 2) + + def _validate_output(self, patch_size: int): + scale_factor = self._get_scale_factor(patch_size) + if scale_factor != self._num_skips: + raise ValueError( + f"With {patch_size=} and {self._num_skips} skip connection layers, " + "spatial dimensions of model output will not match input spatial dimensions" + ) @property def out_channels(self): - return [-1, self._out_chans] + # Fill up with leading zeros to be used in Unet + return [0] * self._num_skips + [self._out_chans] + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + # Return a list of tensors to match other encoders + return [x, super().forward(x)] sam_vit_encoders = { diff --git a/tests/test_sam.py b/tests/test_sam.py index 10a51791..8e8efac3 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -20,7 +20,7 @@ def test_sam_encoder(encoder_name, img_size, patch_size, depth): out = encoder(sample) expected_patches = img_size // patch_size - assert out.size() == torch.Size([1, 256, expected_patches, expected_patches]) + assert out[-1].size() == torch.Size([1, 256, expected_patches, expected_patches]) @pytest.mark.parametrize("decoder_multiclass_output", [True, False]) @@ -41,6 +41,21 @@ def test_sam(decoder_multiclass_output, n_classes): _test_forward_backward(model, sample, test_shape=True) +@pytest.mark.parametrize("model_class", [smp.Unet]) +@pytest.mark.parametrize("decoder_channels,patch_size", [([64, 32, 16, 8], 16), ([64, 32, 16], 8)]) +def test_sam_as_encoder_only(model_class, decoder_channels, patch_size): + img_size = 64 + model = model_class( + "sam-vit_b", + encoder_weights=None, + encoder_depth=3, + encoder_kwargs=dict(img_size=img_size, out_chans=decoder_channels[0], patch_size=patch_size), + decoder_channels=decoder_channels, + ) + smp = torch.ones(1, 3, img_size, img_size) + _test_forward_backward(model, smp, test_shape=True) + + @pytest.mark.skip(reason="Run this test manually as it needs to download weights") def test_sam_weights(): smp.create_model("sam", encoder_name="sam-vit_b", encoder_weights=None, weights="sa-1b") From 12a0db650b1da49f885c79b6ef9dcaf081e5e51d Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Fri, 19 May 2023 10:21:38 +0400 Subject: [PATCH 16/22] ensure sam encoder weights loading --- .../decoders/sam/model.py | 9 ++------- segmentation_models_pytorch/encoders/sam.py | 18 ++++++++++++++++++ tests/test_sam.py | 7 +++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py index ba9e35f5..59c68936 100644 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ b/segmentation_models_pytorch/decoders/sam/model.py @@ -1,4 +1,5 @@ import 
logging +import warnings from typing import Optional, Union, List, Tuple import torch @@ -14,12 +15,6 @@ ) from segmentation_models_pytorch.encoders import get_encoder, sam_vit_encoders, get_pretrained_settings -logger = logging.getLogger("sam") -logger.setLevel(logging.WARNING) -stream = logging.StreamHandler() -logger.addHandler(stream) -logger.propagate = False - class SAM(SegmentationModel): """SAM_ (Segment Anything Model) is a visual transformer based encoder-decoder segmentation @@ -134,7 +129,7 @@ def _load_pretrained_weights(self, encoder_name: str, weights: str): missing, unused = self.load_state_dict(state_dict, strict=False) if len(missing) > 0 or len(unused) > 0: n_loaded = len(state_dict) - len(missing) - len(unused) - logger.warning( + warnings.warn( f"Only {n_loaded} out of pretrained {len(state_dict)} SAM modules are loaded. " f"Missing modules: {missing}. Unused modules: {unused}." ) diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py index 3b309ef7..cd37d57b 100644 --- a/segmentation_models_pytorch/encoders/sam.py +++ b/segmentation_models_pytorch/encoders/sam.py @@ -1,4 +1,6 @@ import math +import warnings +from typing import Mapping, Any import torch from segment_anything.modeling import ImageEncoderViT @@ -39,6 +41,22 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: # Return a list of tensors to match other encoders return [x, super().forward(x)] + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> None: + # Exclude mask_decoder and prompt encoder weights + # and remove 'image_encoder.' prefix + state_dict = { + k.replace("image_encoder.", ""): v + for k, v in state_dict.items() + if not k.startswith("mask_decoder") and not k.startswith("prompt_encoder") + } + missing, unused = super().load_state_dict(state_dict, strict=False) + if len(missing) + len(unused) > 0: + n_loaded = len(state_dict) - len(missing) - len(unused) + warnings.warn( + f"Only {n_loaded} out of pretrained {len(state_dict)} SAM image encoder modules are loaded. " + f"Missing modules: {missing}. Unused modules: {unused}." 
+ ) + sam_vit_encoders = { "sam-vit_h": { diff --git a/tests/test_sam.py b/tests/test_sam.py index 8e8efac3..254b6ab1 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -59,3 +59,10 @@ def test_sam_as_encoder_only(model_class, decoder_channels, patch_size): @pytest.mark.skip(reason="Run this test manually as it needs to download weights") def test_sam_weights(): smp.create_model("sam", encoder_name="sam-vit_b", encoder_weights=None, weights="sa-1b") + + +# @pytest.mark.skip(reason="Run this test manually as it needs to download weights") +def test_sam_encoder_weights(): + smp.create_model( + "unet", encoder_name="sam-vit_b", encoder_weights="sa-1b", encoder_depth=12, decoder_channels=[64, 32, 16, 8] + ) From a049c88a5f5135f16164e1fe25084425ed7f8850 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Fri, 19 May 2023 10:40:40 +0400 Subject: [PATCH 17/22] update segment-anything package source in reqs --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0807c07f..9d7fff6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torchvision>=0.5.0 pretrainedmodels==0.7.4 efficientnet-pytorch==0.7.1 timm==0.6.13 -git+https://github.com/facebookresearch/segment-anything.git +segment-anything-py tqdm pillow From e6cfdc965c966c0958822b021c8ae4bced4fbd8a Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Thu, 25 May 2023 10:01:10 +0400 Subject: [PATCH 18/22] remove sam decoder as it's not stable yet --- segmentation_models_pytorch/__init__.py | 2 - .../decoders/sam/__init__.py | 1 - .../decoders/sam/model.py | 207 ------------------ tests/test_models.py | 2 +- tests/test_sam.py | 3 +- 5 files changed, 3 insertions(+), 212 deletions(-) delete mode 100644 segmentation_models_pytorch/decoders/sam/__init__.py delete mode 100644 segmentation_models_pytorch/decoders/sam/model.py diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index ccf56715..1ac9e1fb 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -12,7 +12,6 @@ from .decoders.pspnet import PSPNet from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus from .decoders.pan import PAN -from .decoders.sam import SAM from .__version__ import __version__ @@ -43,7 +42,6 @@ def create_model( DeepLabV3, DeepLabV3Plus, PAN, - SAM, ] archs_dict = {a.__name__.lower(): a for a in archs} try: diff --git a/segmentation_models_pytorch/decoders/sam/__init__.py b/segmentation_models_pytorch/decoders/sam/__init__.py deleted file mode 100644 index dfc17c68..00000000 --- a/segmentation_models_pytorch/decoders/sam/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .model import SAM diff --git a/segmentation_models_pytorch/decoders/sam/model.py b/segmentation_models_pytorch/decoders/sam/model.py deleted file mode 100644 index 59c68936..00000000 --- a/segmentation_models_pytorch/decoders/sam/model.py +++ /dev/null @@ -1,207 +0,0 @@ -import logging -import warnings -from typing import Optional, Union, List, Tuple - -import torch -from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder -from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom -from torch import nn -from torch.nn import functional as F -from torch.utils import model_zoo - -from segmentation_models_pytorch.base import ( - SegmentationModel, - SegmentationHead, -) -from segmentation_models_pytorch.encoders import get_encoder, sam_vit_encoders, get_pretrained_settings - - 
-class SAM(SegmentationModel): - """SAM_ (Segment Anything Model) is a visual transformer based encoder-decoder segmentation - model that can be used to produce high quality segmentation masks from images and prompts. - Consists of *image encoder*, *prompt encoder* and *mask decoder*. *Segmentation head* is - added after the *mask decoder* to define the final number of classes for the output mask. - - Args: - encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) - to extract features of different spatial resolution - encoder_depth: A number of stages used in encoder in range [6, 24]. Each stage generate features - two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features - with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). - Default is 5 - encoder_weights: One of **None** (random initialization), **"sa-1b"** (pre-training on SA-1B dataset). - decoder_channels: How many output channels image encoder will have. Default is 256. - in_channels: A number of input channels for the model, default is 3 (RGB images) - classes: A number of classes for output mask (or you can think as a number of channels of output mask) - activation: An activation function to apply after the final convolution layer. - Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** - aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build - on top of encoder if **aux_params** is not **None** (default). Supported params: - - classes (int): A number of classes - - pooling (str): One of "max", "avg". Default is "avg" - - dropout (float): Dropout factor in [0, 1) - - activation (str): An activation function to apply "sigmoid"/"softmax" - (could be **None** to return logits) - - Returns: - ``torch.nn.Module``: SAM - - .. 
_SAM: - https://github.com/facebookresearch/segment-anything - - """ - - def __init__( - self, - encoder_name: str = "sam-vit_h", - encoder_depth: int = None, - encoder_weights: Optional[str] = None, - decoder_channels: List[int] = 256, - decoder_multimask_output: bool = True, - in_channels: int = 3, - image_size: int = 1024, - vit_patch_size: int = 16, - classes: int = 1, - weights: Optional[str] = "sa-1b", - activation: Optional[Union[str, callable]] = None, - aux_params: Optional[dict] = None, - ): - super().__init__() - - self.register_buffer("pixel_mean", torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), False) - self.register_buffer("pixel_std", torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), False) - - self.encoder = get_encoder( - encoder_name, - in_channels=in_channels, - depth=encoder_depth, - weights=encoder_weights, - img_size=image_size, - patch_size=vit_patch_size, - out_chans=decoder_channels, - ) - - # this params are used instead of prompt_encoder - image_embedding_size = image_size // vit_patch_size - self.embed_dim = decoder_channels - self.image_embedding_size = (image_embedding_size, image_embedding_size) - self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2) - self.no_mask_embed = nn.Embedding(1, decoder_channels) - - self.decoder = MaskDecoder( - num_multimask_outputs=3, - transformer=TwoWayTransformer( - depth=2, - embedding_dim=decoder_channels, - mlp_dim=2048, - num_heads=8, - ), - transformer_dim=decoder_channels, - iou_head_depth=3, - iou_head_hidden_dim=256, - ) - self._decoder_multiclass_output = decoder_multimask_output - - if weights is not None: - self._load_pretrained_weights(encoder_name, weights) - - self.segmentation_head = SegmentationHead( - in_channels=3 if decoder_multimask_output else 1, - out_channels=classes, - activation=activation, - kernel_size=3, - ) - - if aux_params is not None: - raise NotImplementedError("Auxiliary output is not supported yet") - self.classification_head = None - - self.name = encoder_name - self.initialize() - - def _load_pretrained_weights(self, encoder_name: str, weights: str): - settings = get_pretrained_settings(sam_vit_encoders, encoder_name, weights) - state_dict = model_zoo.load_url(settings["url"]) - state_dict = {k.replace("image_encoder", "encoder"): v for k, v in state_dict.items()} - state_dict = {k.replace("mask_decoder", "decoder"): v for k, v in state_dict.items()} - missing, unused = self.load_state_dict(state_dict, strict=False) - if len(missing) > 0 or len(unused) > 0: - n_loaded = len(state_dict) - len(missing) - len(unused) - warnings.warn( - f"Only {n_loaded} out of pretrained {len(state_dict)} SAM modules are loaded. " - f"Missing modules: {missing}. Unused modules: {unused}." - ) - - def preprocess(self, x): - """Normalize pixel values and pad to a square input.""" - # Normalize colors - x = (x - self.pixel_mean) / self.pixel_std - - # Pad - h, w = x.shape[-2:] - padh = self.encoder.img_size - h - padw = self.encoder.img_size - w - x = F.pad(x, (0, padw, 0, padh)) - return x - - def postprocess_masks( - self, - masks: torch.Tensor, - input_size: Tuple[int, ...], - original_size: Tuple[int, ...], - ) -> torch.Tensor: - """ - Remove padding and upscale masks to the original image size. - - Arguments: - masks (torch.Tensor): Batched masks from the mask_decoder, - in BxCxHxW format. - input_size (tuple(int, int)): The size of the image input to the - model, in (H, W) format. Used to remove padding. 
- original_size (tuple(int, int)): The original size of the image - before resizing for input to the model, in (H, W) format. - - Returns: - (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) - is given by original_size. - """ - masks = F.interpolate( - masks, - (self.encoder.img_size, self.encoder.img_size), - mode="bilinear", - align_corners=False, - ) - masks = masks[..., : input_size[0], : input_size[1]] - masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) - return masks - - def forward(self, x): - img_size = x.shape[-2:] - x = torch.stack([self.preprocess(img) for img in x]) - *_, features = self.encoder(x) - sparse_embeddings, dense_embeddings = self._get_dummy_promp_encoder_output(x.size(0)) - low_res_masks, iou_predictions = self.decoder( - image_embeddings=features, - image_pe=self._get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=self._decoder_multiclass_output, - ) - masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size) - # use scaling below in order to make it work with torch DDP - masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1) - output = self.segmentation_head(masks) - return output - - def _get_dummy_promp_encoder_output(self, bs): - """Use this dummy output as we're training without prompts.""" - sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device) - dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( - bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] - ) - return sparse_embeddings, dense_embeddings - - def _get_dense_pe(self): - return self.pe_layer(self.image_embedding_size).unsqueeze(0) diff --git a/tests/test_models.py b/tests/test_models.py index c5dada3f..08e87b91 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -29,7 +29,7 @@ def get_encoders(): def get_sample(model_class): - if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet, smp.SAM]: + if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet]: sample = torch.ones([1, 3, 64, 64]) elif model_class == smp.PAN: sample = torch.ones([2, 3, 256, 256]) diff --git a/tests/test_sam.py b/tests/test_sam.py index 254b6ab1..4b936754 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -23,6 +23,7 @@ def test_sam_encoder(encoder_name, img_size, patch_size, depth): assert out[-1].size() == torch.Size([1, 256, expected_patches, expected_patches]) +@pytest.mark.skip(reason="Decoder has been removed, keeping this for future integration") @pytest.mark.parametrize("decoder_multiclass_output", [True, False]) @pytest.mark.parametrize("n_classes", [1, 3]) def test_sam(decoder_multiclass_output, n_classes): @@ -61,7 +62,7 @@ def test_sam_weights(): smp.create_model("sam", encoder_name="sam-vit_b", encoder_weights=None, weights="sa-1b") -# @pytest.mark.skip(reason="Run this test manually as it needs to download weights") +@pytest.mark.skip(reason="Run this test manually as it needs to download weights") def test_sam_encoder_weights(): smp.create_model( "unet", encoder_name="sam-vit_b", encoder_weights="sa-1b", encoder_depth=12, decoder_channels=[64, 32, 16, 8] From 9b291249e2a342eeab92be1a09bf7bc222250111 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Fri, 26 May 2023 11:17:36 +0400 Subject: [PATCH 19/22] minor changes from PR review --- README.md | 1 - docs/models.rst | 5 ----- 
requirements.txt | 2 +- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/README.md b/README.md index d6fb4eb9..6925c729 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,6 @@ Congratulations! You are done! Now you can train your model with your favorite f - PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)] - DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)] - DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)] - - SAM [[paper](https://ai.facebook.com/research/publications/segment-anything/)] [[docs](https://github.com/facebookresearch/segment-anything)] #### Encoders diff --git a/docs/models.rst b/docs/models.rst index 06c1f3ce..a5ab52c1 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -36,8 +36,3 @@ DeepLabV3 DeepLabV3+ ~~~~~~~~~~ .. autoclass:: segmentation_models_pytorch.DeepLabV3Plus - -SAM -~~~~~~~~~~ -.. autoclass:: segmentation_models_pytorch.SAM - diff --git a/requirements.txt b/requirements.txt index 9d7fff6f..fec031ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torchvision>=0.5.0 pretrainedmodels==0.7.4 efficientnet-pytorch==0.7.1 timm==0.6.13 -segment-anything-py +segment-anything-py==1.0 tqdm pillow From 5edc0ee9452b5e3b078583f20bfa04141418adc7 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Fri, 2 Jun 2023 12:33:44 +0400 Subject: [PATCH 20/22] use vit_depth to control sam vit depth --- .../decoders/unet/model.py | 12 +---- .../encoders/__init__.py | 9 +--- segmentation_models_pytorch/encoders/sam.py | 46 +++++++++++-------- tests/test_sam.py | 27 ++++++----- 4 files changed, 46 insertions(+), 48 deletions(-) diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 38732ebc..5baf043f 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -65,30 +65,20 @@ def __init__( classes: int = 1, activation: Optional[Union[str, callable]] = None, aux_params: Optional[dict] = None, - encoder_kwargs: Optional[dict] = None, ): super().__init__() - # if sam encoder, make sure to make num_hidden_skips is set - if encoder_name.startswith("sam-"): - encoder_kwargs = encoder_kwargs if encoder_kwargs is not None else {} - encoder_kwargs.update({"num_hidden_skips": len(decoder_channels)}) - n_decoder_blocks = len(decoder_channels) - else: - n_decoder_blocks = encoder_depth - self.encoder = get_encoder( encoder_name, in_channels=in_channels, depth=encoder_depth, weights=encoder_weights, - **encoder_kwargs if encoder_kwargs is not None else {}, ) self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, - n_blocks=n_decoder_blocks, + n_blocks=encoder_depth, use_batchnorm=decoder_use_batchnorm, center=True if encoder_name.startswith("vgg") else False, attention_type=decoder_attention_type, diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 0656c83f..635f44b4 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -97,13 +97,8 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) params = encoders[name]["params"] - if 
name.startswith("sam-"): - params.update(**kwargs) - params.update(dict(name=name[4:])) - if depth is not None: - params.update(depth=depth) - else: - params.update(depth=depth) + params.update(depth=depth) + params.update(kwargs) encoder = Encoder(**params) if weights is not None: diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py index cd37d57b..08aa9435 100644 --- a/segmentation_models_pytorch/encoders/sam.py +++ b/segmentation_models_pytorch/encoders/sam.py @@ -9,33 +9,41 @@ class SamVitEncoder(EncoderMixin, ImageEncoderViT): - def __init__(self, name: str, **kwargs): - patch_size = kwargs.get("patch_size", 16) - n_skips = kwargs.pop("num_hidden_skips", int(self._get_scale_factor(patch_size))) + def __init__(self, **kwargs): + self._vit_depth = kwargs.pop("vit_depth") + self._encoder_depth = kwargs.get("depth", 5) + kwargs.update({"depth": self._vit_depth}) super().__init__(**kwargs) - self._name = name - self._depth = kwargs["depth"] self._out_chans = kwargs.get("out_chans", 256) - self._num_skips = n_skips - self._validate_output(patch_size) + self._patch_size = kwargs.get("patch_size", 16) + self._validate() - @staticmethod - def _get_scale_factor(patch_size: int) -> float: + @property + def output_stride(self): + return 32 + + def _get_scale_factor(self) -> float: """Input image will be downscale by this factor""" - return math.log(patch_size, 2) + return int(math.log(self._patch_size, 2)) - def _validate_output(self, patch_size: int): - scale_factor = self._get_scale_factor(patch_size) - if scale_factor != self._num_skips: + def _validate(self): + # check vit depth + if self._vit_depth not in [12, 24, 32]: + raise ValueError(f"vit_depth must be one of [12, 24, 32], got {self._vit_depth}") + # check output + scale_factor = self._get_scale_factor() + if scale_factor != self._encoder_depth: raise ValueError( - f"With {patch_size=} and {self._num_skips} skip connection layers, " - "spatial dimensions of model output will not match input spatial dimensions" + f"With patch_size={self._patch_size} and depth={self._encoder_depth}, " + "spatial dimensions of model output will not match input spatial dimensions. " + "It is recommended to set encoder depth=4 with default vit patch_size=16." 
) @property def out_channels(self): # Fill up with leading zeros to be used in Unet - return [0] * self._num_skips + [self._out_chans] + scale_factor = self._get_scale_factor() + return [0] * scale_factor + [self._out_chans] def forward(self, x: torch.Tensor) -> list[torch.Tensor]: # Return a list of tensors to match other encoders @@ -66,7 +74,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> }, "params": dict( embed_dim=1280, - depth=32, + vit_depth=32, num_heads=16, global_attn_indexes=[7, 15, 23, 31], ), @@ -78,7 +86,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> }, "params": dict( embed_dim=1024, - depth=24, + vit_depth=24, num_heads=16, global_attn_indexes=[5, 11, 17, 23], ), @@ -90,7 +98,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> }, "params": dict( embed_dim=768, - depth=12, + vit_depth=12, num_heads=12, global_attn_indexes=[2, 5, 8, 11], ), diff --git a/tests/test_sam.py b/tests/test_sam.py index 4b936754..8168bb2f 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -8,11 +8,10 @@ @pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"]) @pytest.mark.parametrize("img_size", [64, 128]) -@pytest.mark.parametrize("patch_size", [8, 16]) -@pytest.mark.parametrize("depth", [6, 24, None]) -def test_sam_encoder(encoder_name, img_size, patch_size, depth): - encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth) - assert encoder._name == encoder_name[4:] +@pytest.mark.parametrize("patch_size,depth", [(8, 3), (16, 4)]) +@pytest.mark.parametrize("vit_depth", [12, 24]) +def test_sam_encoder(encoder_name, img_size, patch_size, depth, vit_depth): + encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth, vit_depth=vit_depth) assert encoder.output_stride == 32 sample = torch.ones(1, 3, img_size, img_size) @@ -23,6 +22,13 @@ def test_sam_encoder(encoder_name, img_size, patch_size, depth): assert out[-1].size() == torch.Size([1, 256, expected_patches, expected_patches]) +def test_sam_encoder_validation_error(): + with pytest.raises(ValueError): + get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=5, vit_depth=12) + get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=None) + get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=6) + + @pytest.mark.skip(reason="Decoder has been removed, keeping this for future integration") @pytest.mark.parametrize("decoder_multiclass_output", [True, False]) @pytest.mark.parametrize("n_classes", [1, 3]) @@ -43,14 +49,13 @@ def test_sam(decoder_multiclass_output, n_classes): @pytest.mark.parametrize("model_class", [smp.Unet]) -@pytest.mark.parametrize("decoder_channels,patch_size", [([64, 32, 16, 8], 16), ([64, 32, 16], 8)]) -def test_sam_as_encoder_only(model_class, decoder_channels, patch_size): - img_size = 64 +@pytest.mark.parametrize("decoder_channels,encoder_depth", [([64, 32, 16, 8], 4), ([64, 32, 16, 8], 4)]) +def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth): + img_size = 1024 model = model_class( "sam-vit_b", encoder_weights=None, - encoder_depth=3, - encoder_kwargs=dict(img_size=img_size, out_chans=decoder_channels[0], patch_size=patch_size), + encoder_depth=encoder_depth, decoder_channels=decoder_channels, ) smp = torch.ones(1, 3, img_size, img_size) @@ -65,5 +70,5 @@ def test_sam_weights(): @pytest.mark.skip(reason="Run this test manually as it needs to download weights") def 
test_sam_encoder_weights(): smp.create_model( - "unet", encoder_name="sam-vit_b", encoder_weights="sa-1b", encoder_depth=12, decoder_channels=[64, 32, 16, 8] + "unet", encoder_name="sam-vit_b", encoder_depth=4, encoder_weights="sa-1b", decoder_channels=[64, 32, 16, 8] ) From e5c4bc4db2762ada698b33097ba5d05890258cd2 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Mon, 5 Jun 2023 09:18:50 +0400 Subject: [PATCH 21/22] rm changes from pan model --- segmentation_models_pytorch/decoders/pan/model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index e618977a..838d3e85 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -58,8 +58,6 @@ def __init__( activation: Optional[Union[str, callable]] = None, upsampling: int = 4, aux_params: Optional[dict] = None, - encoder_kwargs: Optional[dict] = None, - encoder_depth: int = 5, ): super().__init__() @@ -69,10 +67,9 @@ def __init__( self.encoder = get_encoder( encoder_name, in_channels=in_channels, - depth=encoder_depth, + depth=5, weights=encoder_weights, output_stride=encoder_output_stride, - **({} if encoder_kwargs is None else encoder_kwargs), ) self.decoder = PANDecoder( From c5bc3567a0d64bcde2fdb76f826834af8751a368 Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Mon, 12 Jun 2023 16:12:16 +0400 Subject: [PATCH 22/22] implement skip connections for sam vit encoder --- segmentation_models_pytorch/encoders/sam.py | 78 ++++++++++++++++++--- tests/test_sam.py | 50 +++++++------ 2 files changed, 92 insertions(+), 36 deletions(-) diff --git a/segmentation_models_pytorch/encoders/sam.py b/segmentation_models_pytorch/encoders/sam.py index 08aa9435..aac722ba 100644 --- a/segmentation_models_pytorch/encoders/sam.py +++ b/segmentation_models_pytorch/encoders/sam.py @@ -4,6 +4,8 @@ import torch from segment_anything.modeling import ImageEncoderViT +from torch import nn +from segment_anything.modeling.common import LayerNorm2d from segmentation_models_pytorch.encoders._base import EncoderMixin @@ -16,15 +18,55 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self._out_chans = kwargs.get("out_chans", 256) self._patch_size = kwargs.get("patch_size", 16) + self._embed_dim = kwargs.get("embed_dim", 768) self._validate() + self.intermediate_necks = nn.ModuleList( + [self.init_neck(self._embed_dim, out_chan) for out_chan in self.out_channels[:-1]] + ) + + @staticmethod + def init_neck(embed_dim: int, out_chans: int) -> nn.Module: + # Use similar neck as in ImageEncoderViT + return nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + @staticmethod + def neck_forward(neck: nn.Module, x: torch.Tensor, scale_factor: float = 1) -> torch.Tensor: + x = x.permute(0, 3, 1, 2) + if scale_factor != 1.0: + x = nn.functional.interpolate(x, scale_factor=scale_factor, mode="bilinear") + return neck(x) + + def requires_grad_(self, requires_grad: bool = True): + # Keep the intermediate necks trainable + for param in self.parameters(): + param.requires_grad_(requires_grad) + for param in self.intermediate_necks.parameters(): + param.requires_grad_(True) + return self @property def output_stride(self): return 32 - def _get_scale_factor(self) -> float: - """Input image will be downscale by 
this factor""" - return int(math.log(self._patch_size, 2)) + @property + def out_channels(self): + return [self._out_chans // (2**i) for i in range(self._encoder_depth + 1)][::-1] def _validate(self): # check vit depth @@ -39,15 +81,30 @@ def _validate(self): "It is recommended to set encoder depth=4 with default vit patch_size=16." ) - @property - def out_channels(self): - # Fill up with leading zeros to be used in Unet - scale_factor = self._get_scale_factor() - return [0] * scale_factor + [self._out_chans] + def _get_scale_factor(self) -> float: + """Input image will be downscale by this factor""" + return int(math.log(self._patch_size, 2)) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: - # Return a list of tensors to match other encoders - return [x, super().forward(x)] + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + features = [] + skip_steps = self._vit_depth // self._encoder_depth + scale_factor = self._get_scale_factor() + for i, blk in enumerate(self.blocks): + x = blk(x) + if i % skip_steps == 0: + # Double spatial dimension and halve number of channels + neck = self.intermediate_necks[i // skip_steps] + features.append(self.neck_forward(neck, x, scale_factor=2**scale_factor)) + scale_factor -= 1 + + x = self.neck(x.permute(0, 3, 1, 2)) + features.append(x) + + return features def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> None: # Exclude mask_decoder and prompt encoder weights @@ -58,6 +115,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) -> if not k.startswith("mask_decoder") and not k.startswith("prompt_encoder") } missing, unused = super().load_state_dict(state_dict, strict=False) + missing = list(filter(lambda x: not x.startswith("intermediate_necks"), missing)) if len(missing) + len(unused) > 0: n_loaded = len(state_dict) - len(missing) - len(unused) warnings.warn( diff --git a/tests/test_sam.py b/tests/test_sam.py index 8168bb2f..2c377fa9 100644 --- a/tests/test_sam.py +++ b/tests/test_sam.py @@ -13,13 +13,35 @@ def test_sam_encoder(encoder_name, img_size, patch_size, depth, vit_depth): encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth, vit_depth=vit_depth) assert encoder.output_stride == 32 + assert encoder.out_channels == [256 // (2**i) for i in range(depth + 1)][::-1] sample = torch.ones(1, 3, img_size, img_size) with torch.no_grad(): out = encoder(sample) - expected_patches = img_size // patch_size - assert out[-1].size() == torch.Size([1, 256, expected_patches, expected_patches]) + assert len(out) == depth + 1 + + expected_spatial_size = img_size // patch_size + expected_chans = 256 + for i in range(1, len(out)): + assert out[-i].size() == torch.Size([1, expected_chans, expected_spatial_size, expected_spatial_size]) + expected_spatial_size *= 2 + expected_chans //= 2 + + +def test_sam_encoder_trainable(): + encoder = get_encoder("sam-vit_b", depth=4) + + encoder.requires_grad_(False) + for name, param in encoder.named_parameters(): + if name.startswith("intermediate_necks"): + assert param.requires_grad + else: + assert not param.requires_grad + + encoder.requires_grad_(True) + for param in encoder.parameters(): + assert param.requires_grad def test_sam_encoder_validation_error(): @@ -29,25 +51,6 @@ def test_sam_encoder_validation_error(): get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=6) -@pytest.mark.skip(reason="Decoder has been removed, keeping this for future integration") 
-@pytest.mark.parametrize("decoder_multiclass_output", [True, False]) -@pytest.mark.parametrize("n_classes", [1, 3]) -def test_sam(decoder_multiclass_output, n_classes): - model = smp.SAM( - "sam-vit_b", - encoder_weights=None, - weights=None, - image_size=64, - decoder_multimask_output=decoder_multiclass_output, - classes=n_classes, - ) - sample = get_sample(smp.SAM) - model.eval() - - _test_forward(model, sample, test_shape=True) - _test_forward_backward(model, sample, test_shape=True) - - @pytest.mark.parametrize("model_class", [smp.Unet]) @pytest.mark.parametrize("decoder_channels,encoder_depth", [([64, 32, 16, 8], 4), ([64, 32, 16, 8], 4)]) def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth): @@ -62,11 +65,6 @@ def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth): _test_forward_backward(model, smp, test_shape=True) -@pytest.mark.skip(reason="Run this test manually as it needs to download weights") -def test_sam_weights(): - smp.create_model("sam", encoder_name="sam-vit_b", encoder_weights=None, weights="sa-1b") - - @pytest.mark.skip(reason="Run this test manually as it needs to download weights") def test_sam_encoder_weights(): smp.create_model(
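With the decoder removed, the end state of this series exposes the SAM ViT purely as an encoder for the existing architectures. The sketch below mirrors the parameters exercised in test_sam_encoder_arch and test_sam_encoder_weights above; it is a usage sketch under those assumptions, not part of the patch, and it pairs encoder_depth=4 with the default vit patch_size=16 (the combination _validate() accepts):

import torch
import segmentation_models_pytorch as smp

# Unet with the SAM ViT-B image encoder as backbone. encoder_depth=4 equals
# log2(patch_size) for the default patch_size=16, as required by _validate().
model = smp.create_model(
    "unet",
    encoder_name="sam-vit_b",
    encoder_weights=None,  # or "sa-1b" to download the SAM pretrained weights
    encoder_depth=4,
    decoder_channels=[64, 32, 16, 8],
)

# The encoder reports out_channels = [256 // 2**i for i in range(depth + 1)][::-1],
# i.e. [16, 32, 64, 128, 256] at depth 4, which feeds the Unet skip connections.
x = torch.ones(1, 3, 1024, 1024)  # 1024 is the SAM ViT default img_size
with torch.no_grad():
    mask = model(x)  # spatial size matches the input, as asserted in the tests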