
[WIP] Adding DPT #1079

Open · wants to merge 9 commits into base: main

Conversation


@vedantdalimkar vedantdalimkar commented Mar 2, 2025

Tried to address issue #1073

  1. Added a TimmViTEncoder class in the encoders package to support ViT models as encoders.
  2. Added the DPT model architecture (a brief usage sketch is below).
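
For reference, here is a hypothetical usage sketch. The "dpt" architecture name, the timm encoder identifier, and the arguments below are assumptions for illustration, not necessarily the final API of this PR:

import torch
import segmentation_models_pytorch as smp

# Hypothetical example: the encoder name below is an assumed timm ViT identifier.
model = smp.create_model(
    "dpt",
    encoder_name="tu-vit_base_patch16_384",
    classes=1,
)
model.eval()

# Use the resolution the ViT encoder was originally trained on (see the concern below).
x = torch.rand(1, 3, 384, 384)
with torch.inference_mode():
    mask = model(x)  # expected shape: (1, 1, 384, 384)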

I have one concern: right now, the model only works for images at the resolution the chosen ViT encoder was originally trained on. Is this behaviour okay, or should I change the functionality so that dynamic image resolutions are supported?

I have tested various ViT encoders at different encoder depths and the model seems to run correctly.
@qubvel Please let me know if you feel I should make any changes.


qubvel commented Mar 3, 2025

Hi @vedantdalimkar. Thanks a lot for the PR 🤗 I am a bit too busy this week, so I will review it early next week. Thanks for your patience. Meanwhile, can you please set up tests for the encoder and the DPT model? Please see how other models are tested.

@vedantdalimkar (Author)

> Hi @vedantdalimkar. Thanks a lot for the PR 🤗 I am a bit too busy this week, so I will review it early next week. Thanks for your patience. Meanwhile, can you please set up tests for the encoder and the DPT model? Please see how other models are tested.

Hi @qubvel. Sure, I will set up the relevant tests.


vedantdalimkar commented Mar 8, 2025

Hi @qubvel, I did some refactoring. This commit should pass the majority of the tests now. I had a few questions:

  1. The DPT model doesn't seem to be torch-compilable or scriptable, since it has graph breaks. Should I skip those tests?
  2. Currently, the DPT model is not on the HF hub, which is required for the test_preserve_forward_output test. Should I skip this test as well?

Some issues I faced with the default environment for smp development:

  1. The TimmViTEncoder class that I have added requires the latest version of timm, so I have decorated all test functions with a requires_timm_greater_or_equal function (similar to requires_torch_greater_or_equal; a sketch of such a helper is below). If possible, please use the latest timm version in the requirements so that you can run these tests, or let me know if you want me to change this behaviour.
  2. The GELU activation function used in the decoder requires torch >= 2.0. However, the requirements pin a lower torch version. Can the requirements be updated?
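
For reference, a minimal sketch of what such a version-gating helper could look like (the helper name matches the description above, but its exact implementation and location in the test utilities are assumptions):

import unittest

from packaging import version


def requires_timm_greater_or_equal(min_version: str):
    # Skip the decorated test unless timm is installed and at least `min_version`.
    try:
        import timm

        timm_ok = version.parse(timm.__version__) >= version.parse(min_version)
    except ImportError:
        timm_ok = False

    return unittest.skipUnless(
        timm_ok, f"timm >= {min_version} is required for this test"
    )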

@vedantdalimkar (Author)

Hi @qubvel. Just a gentle reminder, I think this PR is ready for review.


qubvel commented Mar 14, 2025

Hey @vedantdalimkar, thanks for the ping, and sorry for the delay! Will try to do a pass today or on Monday. Thank you for your patience 🤗

@qubvel qubvel self-requested a review March 18, 2025 09:50

codecov bot commented Mar 18, 2025

Codecov Report

Attention: Patch coverage is 86.86441% with 31 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| segmentation_models_pytorch/encoders/timm_vit.py | 72.04% | 26 Missing ⚠️ |
| ...egmentation_models_pytorch/decoders/dpt/decoder.py | 95.00% | 5 Missing ⚠️ |

| Files with missing lines | Coverage | Δ |
|---|---|---|
| segmentation_models_pytorch/__init__.py | 93.10% <100.00%> | (+0.24%) ⬆️ |
| ...gmentation_models_pytorch/decoders/dpt/__init__.py | 100.00% <100.00%> | (ø) |
| segmentation_models_pytorch/decoders/dpt/model.py | 100.00% <100.00%> | (ø) |
| segmentation_models_pytorch/encoders/__init__.py | 76.00% <100.00%> | (+1.00%) ⬆️ |
| ...egmentation_models_pytorch/decoders/dpt/decoder.py | 95.00% <95.00%> | (ø) |
| segmentation_models_pytorch/encoders/timm_vit.py | 72.04% <72.04%> | (ø) |

... and 2 files with indirect coverage changes


Comment on lines 209 to 228
def load_state_dict(self, state_dict, **kwargs):
    # for compatibility of weights for
    # timm-ported encoders with TimmUniversalEncoder
    patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]

    is_deprecated_encoder = any(
        self.name.startswith(pattern) for pattern in patterns
    )

    if is_deprecated_encoder:
        keys = list(state_dict.keys())
        for key in keys:
            new_key = key
            if not key.startswith("model."):
                new_key = "model." + key
            if "gernet" in self.name:
                new_key = new_key.replace(".stages.", ".stages_")
            state_dict[new_key] = state_dict.pop(key)

    return super().load_state_dict(state_dict, **kwargs)
Collaborator

This is not needed for this kind of encoder

Comment on lines 231 to 246
def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    """
    Merge two dictionaries, ensuring no duplicate keys exist.

    Args:
        a (dict): Base dictionary.
        b (dict): Additional parameters to merge.

    Returns:
        dict: A merged dictionary.
    """
    duplicates = a.keys() & b.keys()
    if duplicates:
        raise ValueError(f"'{duplicates}' already specified internally")

    return a | b
Collaborator

This can be imported if needed, no need to duplicate code

@qubvel qubvel (Collaborator) left a comment

Hi @vedantdalimkar! Great work, and sorry once again for the delay 🤗 Thank you for working on the model, it looks super good for a first iteration. I think we need just a few steps to get it merged 🚀

Here is what's missing:

  1. Conversion script and integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the segformer one (see the examples/ folder)
  2. We need some docs and a table to clarify which encoders are supported, since it's a bit different from other models, which support convolutional and transformer encoders
  3. Refine the tests a bit (see comment below)

Other than that it looks clean, thank you for your hard work 🙌

    intermediates_only=True,
)

cls_tokens = [None] * len(self.out_indices)
Collaborator

Why do we need to provide CLS tokens?

@vedantdalimkar vedantdalimkar (Author) Mar 18, 2025

The DPT architecture requires CLS tokens. The motivation is provided in the DPT paper:

> ... the readout token doesn’t serve a clear purpose for the task of dense prediction, but could potentially still be useful to capture and distribute global information.

The default setting of the DPT architecture broadcasts the CLS token and concatenates it with the patch features along the feature dimension before projecting the patch features back to the original feature dimension.

This is why the CLS token is needed.
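
For context, a minimal sketch of the "project"-style readout described above (module and variable names are illustrative, not the exact code in this PR):

import torch
import torch.nn as nn


class ProjectReadout(nn.Module):
    # Fuse the CLS (readout) token into every patch token, then project back
    # to the original embedding dimension, as described in the DPT paper.
    def __init__(self, embed_dim: int):
        super().__init__()
        self.project = nn.Sequential(nn.Linear(2 * embed_dim, embed_dim), nn.GELU())

    def forward(self, patch_tokens: torch.Tensor, cls_token: torch.Tensor) -> torch.Tensor:
        # patch_tokens: (batch, num_patches, embed_dim); cls_token: (batch, embed_dim)
        cls_token = cls_token.unsqueeze(1).expand_as(patch_tokens)  # broadcast over patches
        fused = torch.cat([patch_tokens, cls_token], dim=-1)  # (batch, num_patches, 2 * embed_dim)
        return self.project(fused)  # back to (batch, num_patches, embed_dim)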

Collaborator

Got it, thanks!

Comment on lines +163 to +177
if self.num_prefix_tokens > 0:
    features, prefix_tokens = zip(*intermediate_outputs)
    if self.cls_token_supported:
        if self.num_prefix_tokens == 1:
            cls_tokens = prefix_tokens

        elif self.num_prefix_tokens > 1:
            cls_tokens = [
                prefix_token[:, 0, :] for prefix_token in prefix_tokens
            ]

else:
    features = intermediate_outputs

return features, cls_tokens
Collaborator

I suppose we can return just the features without the prefix tokens, can't we?

Author

There are a few possibilities w.r.t. prefix tokens in a timm ViT model:

  1. No CLS token, no register tokens
  2. CLS token present, no register tokens
  3. CLS token present, multiple register tokens

I have tried to cover all of these cases so that only the features and the CLS token are returned, and not the register tokens.

Collaborator

Ok, thanks for the clarification, let's add this as a comment to the code

Comment on lines +100 to +107
else:
    self.upsample = nn.Conv2d(
        in_channels=out_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=int(1 / upsample_factor),
        padding=1,
    )
Collaborator

Hmm, interesting. Is this code reachable?

@vedantdalimkar vedantdalimkar (Author) Mar 18, 2025

Yes, I think so. Let me provide a bit of explanation so it will be much clearer.

The purpose of the FeatureProcessBlock layer is to resize the patch features (which are the outputs of the transformer layers) so that the features passed to the DPT decoder are multi-scale. This means that each FeatureProcessBlock layer either upsamples or downsamples the patch features so that the features given to the decoder have specific spatial ratios relative to the input image spatial dimensions: [1/4, 1/8, 1/16, 1/32].

This is how the upsample factor is calculated for each FeatureProcessBlock layer of the decoder -

upsample_factors = [
            (encoder_output_stride / 2 ** (index + 2))
            for index in range(0, encoder_depth)
        ]

So for a ViT encoder with an output stride of 16, the upsample factors would be [4, 2, 1, 1/2]. Thus, for the last FeatureProcessBlock layer, the upsample factor would be 1/2 and the code in the last else block would be reachable (the stride of the conv layer inside that FeatureProcessBlock would be 2).

Also, it is possible that (encoder_output_stride / 2 ** (index + 2)) is not a power of 2, but I have put a constraint on the encoder output stride so that only ViT encoders whose output stride is a power of 2 are allowed for DPT.
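
For illustration, here is a minimal sketch of how the resize step of a FeatureProcessBlock could be chosen from the upsample factor. Only the downsampling branch is quoted from the PR above; the transposed-convolution and identity branches are assumptions made for this sketch:

import torch.nn as nn


def build_resize(out_channels: int, upsample_factor: float) -> nn.Module:
    if upsample_factor > 1:
        # e.g. factors 4 and 2 for an output-stride-16 ViT encoder
        scale = int(upsample_factor)
        return nn.ConvTranspose2d(out_channels, out_channels, kernel_size=scale, stride=scale)
    elif upsample_factor == 1:
        return nn.Identity()
    else:
        # e.g. factor 1/2 -> stride-2 convolution (the reachable branch discussed above)
        return nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=int(1 / upsample_factor), padding=1)


# For an encoder with output stride 16 and depth 4:
# [16 / 2 ** (i + 2) for i in range(4)] -> [4.0, 2.0, 1.0, 0.5]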

Collaborator

Ok, sounds good then, thanks for taking the time to clarify 👍

Collaborator

To remove?

Author

Yes, my bad. I need to update this.

Comment on lines 32 to 51
@property
def model_class(self):
    return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type]

@property
def decoder_channels(self):
    signature = inspect.signature(self.model_class)
    # check if decoder_channels is in the signature
    if "decoder_channels" in signature.parameters:
        return signature.parameters["decoder_channels"].default
    return None

@lru_cache
def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
    batch_size = batch_size or self.default_batch_size
    num_channels = num_channels or self.default_num_channels
    height = height or self.default_height
    width = width or self.default_width
    return torch.rand(batch_size, num_channels, height, width)

Collaborator

No need to override this; it's the same as in the base class.

Comment on lines 60 to 79
def test_forward_backward(self):
    sample = self._get_sample().to(default_device)

    model = self.get_default_model()

    # check default in_channels=3
    output = model(sample)

    # check default output number of classes = 1
    expected_number_of_classes = 1
    result_number_of_classes = output.shape[1]
    self.assertEqual(
        result_number_of_classes,
        expected_number_of_classes,
        f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}",
    )

    # check backward pass
    output.mean().backward()

Collaborator

No need to override

Suggested change: remove the overridden test_forward_backward entirely (it is inherited from the base class).

Comment on lines 233 to 276
@pytest.mark.torch_export
def test_torch_export(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    exported_model = torch.export.export(
        model,
        args=(sample,),
        strict=True,
    )

    with torch.inference_mode():
        eager_output = model(sample)
        exported_output = exported_model.module().forward(sample)

    self.assertEqual(eager_output.shape, exported_output.shape)
    torch.testing.assert_close(eager_output, exported_output)

@pytest.mark.torch_script
def test_torch_script(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    if not model._is_torch_scriptable:
        with self.assertRaises(RuntimeError):
            scripted_model = torch.jit.script(model)
        return

    scripted_model = torch.jit.script(model)

    with torch.inference_mode():
        scripted_output = scripted_model(sample)
        eager_output = model(sample)

    self.assertEqual(scripted_output.shape, eager_output.shape)
    torch.testing.assert_close(scripted_output, eager_output)
Collaborator

No need to override; please leave only modified tests here, other tests will be fetched from the base class.

Comment on lines 54 to 56
model = smp.create_model(
    self.model_type, self.test_encoder_name, output_stride=None
)
Collaborator

Can we avoid passing output_stride=None, to keep it consistent with other models? It can be handled at the create_model level or the encoder level, for example.


vedantdalimkar commented Mar 18, 2025

Hey @qubvel, I have a concern regarding the following point:

>   1. Conversion script and integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the segformer one (see the examples/ folder)

The DPT architecture has a different segmentation head compared to the standard SMP segmentation head. I am guessing I would need to keep the head the same as in the original architecture to ensure that the logits match the original model. Is it fine if the smp.DPT model has a different segmentation head?

If yes, should I include the new segmentation head class under decoders/dpt/head.py?

Edit: In case the hub repo for the model doesn't have any input/output tensors, what should be done to test consistency between the hub model output and the smp model output?


qubvel commented Mar 18, 2025

> Is it fine if the smp.DPT model has a different segmentation head?
> If yes, should I include the new segmentation head class under decoders/dpt/head.py?

Yes, it's fine for it to have its own head; you can put it in decoder.py.

> In case the hub repo for the model doesn't have any input/output tensors, what should be done to test consistency between the hub model output and the smp model output?

You can run the HF model on an all-ones tensor to get the expected output, then use it in the test. Please see the Segformer test:

def test_load_pretrained(self):
    hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k"

    model = smp.from_pretrained(hub_checkpoint)
    model = model.eval().to(default_device)

    sample = torch.ones([1, 3, 512, 512]).to(default_device)

    with torch.inference_mode():
        output = model(sample)

    self.assertEqual(output.shape, (1, 150, 512, 512))

    expected_logits_slice = torch.tensor(
        [-4.4172, -4.4723, -4.5273, -4.5824, -4.6375, -4.7157]
    )
    resulted_logits_slice = output[0, 0, 256, :6].cpu()

    is_equal = torch.allclose(
        expected_logits_slice, resulted_logits_slice, atol=1e-2
    )
    max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice))
    self.assertTrue(
        is_equal,
        f"Expected logits slice and resulted logits slice are not equal.\n"
        f"Max diff: {max_diff}\n"
        f"Expected: {expected_logits_slice}\n"
        f"Resulted: {resulted_logits_slice}\n",
    )

@vedantdalimkar (Author)

I am still confused about a couple of things:

>   1. Conversion script and integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the segformer one (see the examples/ folder)

i) By "original" model, do you mean the model uploaded on HF by the paper authors (who wrote the paper for DPT)?
ii) Right now, smp-hub doesn't have any model for DPT; is it supposed to be uploaded by us?


qubvel commented Mar 18, 2025

> i) By "original" model, do you mean the model uploaded on HF by the paper authors (who wrote the paper for DPT)?

Any of those; they should be equivalent.

> ii) Right now, smp-hub doesn't have any model for DPT; is it supposed to be uploaded by us?

Yes, you can upload it to your own HF account; I will transfer it to smp-hub as soon as the PR is ready 🤗


vedantdalimkar commented Mar 22, 2025

> Hi @vedantdalimkar! Great work, and sorry once again for the delay 🤗 Thank you for working on the model, it looks super good for a first iteration. I think we need just a few steps to get it merged 🚀
>
> Here is what's missing:
>
>   1. Conversion script and integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the segformer one (see the examples/ folder)
>   2. We need some docs and a table to clarify which encoders are supported, since it's a bit different from other models, which support convolutional and transformer encoders
>   3. Refine the tests a bit (see comment below)
>
> Other than that it looks clean, thank you for your hard work 🙌

Hey @qubvel, I have worked on the points mentioned above; it should hopefully pass the required tests now (except the torch.compile test).

  1. For generating logits from the original model, I used the segmentation model in the original DPT repository, with the dpt_large-ade20k-b12dca68.pt checkpoint for weight conversion.

Sample code I used for logit generation from the original model:

from dpt.models import DPTSegmentationModel
import torch

input = torch.ones((1, 3, 384, 384))
model = DPTSegmentationModel(
    num_classes=150,
    path=r"C:\Users\vedan\Downloads\dpt_large-ade20k-b12dca68.pt",
    backbone="vitl16_384",
)

model.eval()
with torch.no_grad():
    output = model(input)

I have also provided the model conversion script in the misc folder. I have uploaded the converted model here: https://huggingface.co/vedantdalimkar/DPT

  2. I have added some logic in the generate_timm_tables.py script so that the table contains ViT-like encoders in timm which can be used with DPT.
    I faced quite a lot of issues while generating the smp encoders table (generate_table.py) since I am working on a Windows OS. If possible, can you please generate the smp encoders table for me from your end?

  3. I have made the suggested changes to the tests.
