[WIP] Adding DPT #1079
base: main
Conversation
Hi @vedantdalimkar. Thanks a lot for the PR 🤗 I am a bit too busy this week, so will review it early next week. Thanks for your patience. Meanwhile, can you please set up tests for the encoder and the DPT model? Please see how other models are tested
Hi @qubvel. Sure, I will set up the relevant tests.
Hi @qubvel, did some refactoring. This commit should pass the majority of the tests now. Had a few questions -
Some issues I faced with the default environment for smp development.
Hi @qubvel. Just a gentle reminder, I think this PR is ready for review.
Hey @vedantdalimkar, thanks for the ping, and sorry for the delay! Will try to do a pass today or on Monday. Thank you for your patience 🤗
def load_state_dict(self, state_dict, **kwargs):
    # for compatibility of weights for
    # timm-ported encoders with TimmUniversalEncoder
    patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]

    is_deprecated_encoder = any(
        self.name.startswith(pattern) for pattern in patterns
    )

    if is_deprecated_encoder:
        keys = list(state_dict.keys())
        for key in keys:
            new_key = key
            if not key.startswith("model."):
                new_key = "model." + key
            if "gernet" in self.name:
                new_key = new_key.replace(".stages.", ".stages_")
            state_dict[new_key] = state_dict.pop(key)

    return super().load_state_dict(state_dict, **kwargs)
This is not needed for this kind of encoder
def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    """
    Merge two dictionaries, ensuring no duplicate keys exist.

    Args:
        a (dict): Base dictionary.
        b (dict): Additional parameters to merge.

    Returns:
        dict: A merged dictionary.
    """
    duplicates = a.keys() & b.keys()
    if duplicates:
        raise ValueError(f"'{duplicates}' already specified internally")

    return a | b
This can be imported if needed, no need to duplicate code
Hi @vedantdalimkar! Great work and sorry once again for the delay 🤗 Thank you for working on the model, it looks super good for the first iteration; I think we need just a few steps to get it merged 🚀
Here is what's missing:
- Conversion script and integration test to ensure the model produces the same logits as the original one. In addition, it would be great to create a notebook with an inference example, similar to the Segformer one (see the examples/ folder).
- We need some docs and a table to clarify which encoders are supported, because it's a bit different from other models, which support convolutional and transformer encoders.
- Refine the tests a bit (see comment below).
Other than that it looks clean, thank you for your hard work 🙌
    intermediates_only=True,
)

cls_tokens = [None] * len(self.out_indices)
Why do we need to provide CLS tokens?
The DPT architecture makes use of the CLS token. The motivation for this is given in the DPT paper -
... the readout token doesn’t serve a clear purpose for the task of dense prediction, but could potentially still be useful to capture and distribute global information.
In its default readout setting, DPT broadcasts the CLS token, concatenates it with the patch features along the feature dimension, and then projects the patch features back to the original feature dimension. This is why the CLS token is needed.
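For illustration only, here is a minimal sketch of that kind of readout (broadcast the CLS token, concatenate, project back), assuming patch features of shape (B, N, C) and a CLS token of shape (B, C); this is not the exact code from the PR:

```python
import torch
import torch.nn as nn


class ProjectReadout(nn.Module):
    """Sketch of a 'project'-style readout: fuse the CLS token into every patch token."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # 2 * embed_dim -> embed_dim, because CLS and patch features are concatenated
        self.project = nn.Sequential(nn.Linear(2 * embed_dim, embed_dim), nn.GELU())

    def forward(self, patch_tokens: torch.Tensor, cls_token: torch.Tensor) -> torch.Tensor:
        # patch_tokens: (B, N, C), cls_token: (B, C)
        cls_token = cls_token.unsqueeze(1).expand_as(patch_tokens)  # broadcast to (B, N, C)
        fused = torch.cat([patch_tokens, cls_token], dim=-1)  # (B, N, 2C)
        return self.project(fused)  # back to (B, N, C)
```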
Got it, thanks!
if self.num_prefix_tokens > 0:
    features, prefix_tokens = zip(*intermediate_outputs)
    if self.cls_token_supported:
        if self.num_prefix_tokens == 1:
            cls_tokens = prefix_tokens

        elif self.num_prefix_tokens > 1:
            cls_tokens = [
                prefix_token[:, 0, :] for prefix_token in prefix_tokens
            ]

else:
    features = intermediate_outputs

return features, cls_tokens
I suppose we can return just the features without the prefix tokens, can't we?
There can be 2 possibilities w.r.t. prefix tokens in a timm ViT model -
- No CLS token, no register tokens
- CLS token present, multiple register tokens
I have tried to include all edge cases so that only the features and the CLS token are returned, and not the register tokens.
Ok, thanks for the clarification, let's add this as a comment to the code
else:
    self.upsample = nn.Conv2d(
        in_channels=out_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=int(1 / upsample_factor),
        padding=1,
    )
Hmm, interesting, is this code reachable?
Yes, I think so. Let me provide a bit of explanation so it will be much clearer.
The purpose of the FeatureProcessBlock layer is to resample the patch features (which are the outputs of the transformer layers) so that the features passed to the DPT decoder are multi-scale. This means that each FeatureProcessBlock layer either upsamples or downsamples the patch features so that the features given to the decoder have specific spatial ratios relative to the input image spatial dimensions: [1/4, 1/8, 1/16, 1/32].
This is how the upsample factor is calculated for each FeatureProcessBlock layer of the decoder -
upsample_factors = [
    (encoder_output_stride / 2 ** (index + 2))
    for index in range(0, encoder_depth)
]
So for a ViT encoder with an output stride of 16, the upsample factors would be [4, 2, 1, 1/2]. Thus, for the last FeatureProcessBlock layer, the upsample factor is 1/2 and the code in the last else block is reachable (the stride for the conv layer inside that FeatureProcessBlock would be 2).
Also, it is possible that (encoder_output_stride / 2 ** (index + 2)) is not a power of 2, but I have put a constraint on the encoder output stride so that only ViT encoders whose output stride is a power of 2 are allowed for DPT.
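As a quick worked example (a sketch that just mirrors the formula above; the variable names follow the explanation, not necessarily the module code):

```python
# Scale factors for a ViT encoder with output stride 16 and encoder depth 4,
# computed with the formula described above.
encoder_output_stride = 16
encoder_depth = 4

upsample_factors = [
    encoder_output_stride / 2 ** (index + 2) for index in range(encoder_depth)
]
print(upsample_factors)  # [4.0, 2.0, 1.0, 0.5]

# A factor below 1 means the block downsamples rather than upsamples,
# e.g. with a stride-2 convolution, since int(1 / 0.5) == 2.
```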
Ok, sounds good then, thanks for taking time to clarify 👍
encoders_table.md
To remove?
Yes, my bad. I need to update this.
tests/models/test_dpt.py
@property
def model_class(self):
    return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type]

@property
def decoder_channels(self):
    signature = inspect.signature(self.model_class)
    # check if decoder_channels is in the signature
    if "decoder_channels" in signature.parameters:
        return signature.parameters["decoder_channels"].default
    return None

@lru_cache
def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
    batch_size = batch_size or self.default_batch_size
    num_channels = num_channels or self.default_num_channels
    height = height or self.default_height
    width = width or self.default_width
    return torch.rand(batch_size, num_channels, height, width)
No need to override this, the same as in base class
tests/models/test_dpt.py
def test_forward_backward(self):
    sample = self._get_sample().to(default_device)

    model = self.get_default_model()

    # check default in_channels=3
    output = model(sample)

    # check default output number of classes = 1
    expected_number_of_classes = 1
    result_number_of_classes = output.shape[1]
    self.assertEqual(
        result_number_of_classes,
        expected_number_of_classes,
        f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}",
    )

    # check backward pass
    output.mean().backward()
No need to override
tests/models/test_dpt.py
@pytest.mark.torch_export
def test_torch_export(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    exported_model = torch.export.export(
        model,
        args=(sample,),
        strict=True,
    )

    with torch.inference_mode():
        eager_output = model(sample)
        exported_output = exported_model.module().forward(sample)

    self.assertEqual(eager_output.shape, exported_output.shape)
    torch.testing.assert_close(eager_output, exported_output)

@pytest.mark.torch_script
def test_torch_script(self):
    if not check_run_test_on_diff_or_main(self.files_for_diff):
        self.skipTest("No diff and not on `main`.")

    sample = self._get_sample().to(default_device)
    model = self.get_default_model()
    model.eval()

    if not model._is_torch_scriptable:
        with self.assertRaises(RuntimeError):
            scripted_model = torch.jit.script(model)
        return

    scripted_model = torch.jit.script(model)

    with torch.inference_mode():
        scripted_output = scripted_model(sample)
        eager_output = model(sample)

    self.assertEqual(scripted_output.shape, eager_output.shape)
    torch.testing.assert_close(scripted_output, eager_output)
No need to override, please leave only modified tests here, other tests will be fetched from base class
tests/models/test_dpt.py
model = smp.create_model(
    self.model_type, self.test_encoder_name, output_stride=None
)
Can we avoid passing output_stride=None to make it consistent with other models? It can be handled on the create_model level or the encoder level, for example.
Hey @qubvel, I had a concern regarding the following point.
The DPT architecture has a different segmentation head compared to the standard SMP segmentation head. I am guessing I would need to keep the head the same as the original architecture in order to ensure that the logits match the original model. Is it fine if the smp.DPT model has a different segmentation head? If yes, should I include the new segmentation head class under decoders/dpt/head.py?
Edit: In the case that the hub repo for the model doesn't have any output/input tensors, what should be done to test consistency between the hub model output and the smp model output?
Yes, it's fine to have its own head, you can put it in
You can run the HF model on an all-ones tensor to get the expected output, then use it in the test; please see the Segformer test (segmentation_models.pytorch/tests/models/test_segformer.py, lines 16 to 43 in 4aa36c6).
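For illustration, a minimal sketch of how the expected logits could be generated from the original model (the checkpoint id is a placeholder, and transformers' DPTForSemanticSegmentation is only one possible source of the original weights):

```python
import torch
from transformers import DPTForSemanticSegmentation  # assumes the original weights are on the HF hub

# Placeholder checkpoint id - use whichever original DPT weights are being converted.
hf_model = DPTForSemanticSegmentation.from_pretrained("<original-dpt-checkpoint>").eval()

# Deterministic all-ones input, as suggested above.
sample = torch.ones(1, 3, 384, 384)

with torch.inference_mode():
    expected_logits = hf_model(pixel_values=sample).logits

# Print a small slice to hard-code in the smp integration test, then compare it
# against the converted model's output with torch.testing.assert_close.
print(expected_logits.shape)
print(expected_logits[0, 0, :3, :3])
```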
I am still confused regarding a couple of things -
i) By "original" model, do you mean the model uploaded on HF by the paper authors (who wrote the DPT paper)?
Any of those, they should be equivalent.
Yes, you can upload it to your own HF account, I will transfer it to
Hey @qubvel, I have worked on the points mentioned above; hopefully it should pass the required tests now (except the torch.compile test).
Sample code which I have used for logit generation from the original model - Code
I have also provided the model conversion script in the misc folder. I have uploaded the model after weight conversion here - https://huggingface.co/vedantdalimkar/DPT
Tried to address issue #1073
I have one concern: right now the model only works for images having the same resolution as the original input resolution on which a particular ViT encoder was trained. Is this behaviour okay, or should I change the functionality so that dynamic image resolution is supported?
I have tested various ViT encoders at different encoder depths and the model seems to run correctly.
@qubvel Please let me know if you feel I should make any changes.