[WIP] Adding DPT #1079

Open · wants to merge 11 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -75,6 +75,7 @@ target/

 # Jupyter Notebook
 .ipynb_checkpoints
+*ipynb*
 
 # pyenv
 .python-version
1,474 changes: 1,474 additions & 0 deletions docs/timm_encoders.txt

Large diffs are not rendered by default.

59 changes: 51 additions & 8 deletions misc/generate_table_timm.py
@@ -17,30 +17,68 @@ def has_dilation_support(name):
     return False
 
 
+def valid_vit_encoder_for_dpt(name):
+    if "vit" not in name:
+        return False
+    encoder = timm.create_model(name)
+    feature_info = encoder.feature_info
+    feature_info_obj = timm.models.FeatureInfo(
+        feature_info=feature_info, out_indices=[0, 1, 2, 3]
+    )
+    reduction_scales = list(feature_info_obj.reduction())
+
+    if len(set(reduction_scales)) > 1:
+        return False
+
+    output_stride = reduction_scales[0]
+    if bin(output_stride).count("1") != 1:
+        return False
+
+    return True


 def make_table(data):
     names = data.keys()
     max_len1 = max([len(x) for x in names]) + 2
     max_len2 = len("support dilation") + 2
+    max_len3 = len("Supported for DPT") + 2
 
-    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
-    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
+    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n"
+    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+\n"
     top = (
         "| "
         + "Encoder name".ljust(max_len1 - 2)
         + " | "
         + "Support dilation".center(max_len2 - 2)
+        + " | "
+        + "Supported for DPT".center(max_len3 - 2)
         + " |\n"
     )
 
     table = l1 + top + l2
 
     for k in sorted(data.keys()):
-        support = (
-            "✅".center(max_len2 - 3)
-            if data[k]["has_dilation"]
-            else " ".center(max_len2 - 2)
-        )
+        if "has_dilation" in data[k] and data[k]["has_dilation"]:
+            support = "✅".center(max_len2 - 3)
+
+        else:
+            support = " ".center(max_len2 - 2)
+
+        if "supported_only_for_dpt" in data[k]:
+            supported_for_dpt = "✅".center(max_len3 - 3)
+
+        else:
+            supported_for_dpt = " ".center(max_len3 - 2)
+
-        table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
+        table += (
+            "| "
+            + k.ljust(max_len1 - 2)
+            + " | "
+            + support
+            + " | "
+            + supported_for_dpt
+            + " |\n"
+        )
         table += l1
 
     return table
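For reference, a rough sketch of the grid this builds (encoder name and column widths are illustrative, not real output; note that l2 draws its third segment with "-" rather than "="):

  +--------------+------------------+-------------------+
  | Encoder name | Support dilation | Supported for DPT |
  +==============+==================+-------------------+
  | some_encoder |        ✅        |                   |
  +--------------+------------------+-------------------+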
@@ -55,8 +93,13 @@ def make_table(data):
         check_features_and_reduction(name)
         has_dilation = has_dilation_support(name)
         supported_models[name] = dict(has_dilation=has_dilation)
+
     except Exception:
-        continue
+        try:
+            if valid_vit_encoder_for_dpt(name):
+                supported_models[name] = dict(supported_only_for_dpt=True)
+        except Exception:
+            continue
Comment on lines +96 to +102

Collaborator:

Should we check only if we got an exception here?
Would it be better to make two independent checks?

Author:

If you check the behaviour of the functions check_features_and_reduction and valid_vit_encoder_for_dpt, you'll see that their outputs are mutually exclusive. In more detail:

  1. check_features_and_reduction returns true only when reduction scales of a model are equal to [2, 4, 8, 16, 32], whereas,
  2. valid_vit_encoder_for_dpt returns false if the encoder has multiple reduction scales.

In short, a model which satisfies the conditions specified by check_features_and_reduction will never satisfy the conditions set by valid_vit_encoder_for_dpt and vice versa.
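A minimal illustration of that mutual exclusivity (not from the PR; the reduction values are made-up examples):

  # CNN-style encoder: several distinct reduction scales, so it can pass
  # check_features_and_reduction but is rejected by valid_vit_encoder_for_dpt.
  cnn_reductions = [2, 4, 8, 16, 32]

  # ViT-style encoder: one repeated reduction scale, so the opposite holds.
  vit_reductions = [16, 16, 16, 16]

  assert len(set(cnn_reductions)) > 1   # fails the ViT check
  assert len(set(vit_reductions)) == 1  # fails the CNN check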

Collaborator:

Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Author:

> Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Should I update this as well or will you do it from your end?
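A hedged sketch of what such an update might look like (illustrative only; the helper is hypothetical, and the accepted reduction sets are taken from the comment above):

  # Hypothetical helper, not part of the PR.
  SUPPORTED_REDUCTIONS = (
      [2, 4, 8, 16, 32],
      [4, 8, 16, 32],
      [1, 2, 4, 8, 16, 32],
  )

  def reduction_is_supported(reduction_scales):
      return list(reduction_scales) in SUPPORTED_REDUCTIONS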


table = make_table(supported_models)
print(table)
100 changes: 100 additions & 0 deletions scripts/models-conversions/dpt-original-to-smp.py
@@ -0,0 +1,100 @@
import segmentation_models_pytorch as smp
import torch

MODEL_WEIGHTS_PATH = r"C:\Users\vedan\Downloads\dpt_large-ade20k-b12dca68.pt"
HF_HUB_PATH = "vedantdalimkar/DPT"

if __name__ == "__main__":
    smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150)
    dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH)

    for layer_index in range(0, 4):
        for param in [
            "running_mean",
            "running_var",
            "num_batches_tracked",
            "weight",
            "bias",
        ]:
            for block_index in [1, 2]:
                for bn_index in [1, 2]:
                    # Assigning the weights of the 4th fusion layer of the original
                    # model to the 1st layer of the SMP DPT model, the weights of
                    # the 3rd fusion layer to the 2nd layer, and so on.

                    # This is because the order in which the fusion layers are called
                    # is reversed in the original DPT implementation.
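                    # Illustrative mapping implied by the pops below (assumed, for clarity):
                    #   scratch.refinenet4.* -> decoder.fusion_blocks.0.*
                    #   scratch.refinenet1.* -> decoder.fusion_blocks.3.*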

                    dpt_model_dict[
                        f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"
                    ] = dpt_model_dict.pop(
                        f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}"
                    )

            if param in ["weight", "bias"]:
                if param == "weight":
                    for block_index in [1, 2]:
                        for conv_index in [1, 2]:
                            dpt_model_dict[
                                f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"
                            ] = dpt_model_dict.pop(
                                f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}"
                            )

                dpt_model_dict[
                    f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"
                ] = dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}")

                dpt_model_dict[
                    f"decoder.fusion_blocks.{layer_index}.project.{param}"
                ] = dpt_model_dict.pop(
                    f"scratch.refinenet{4 - layer_index}.out_conv.{param}"
                )

                dpt_model_dict[
                    f"decoder.readout_blocks.{layer_index}.project.0.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}"
                )

                dpt_model_dict[
                    f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.3.{param}"
                )

                if layer_index != 2:
                    dpt_model_dict[
                        f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"
                    ] = dpt_model_dict.pop(
                        f"pretrained.act_postprocess{layer_index + 1}.4.{param}"
                    )

    # Changing state dict keys for the segmentation head
    dpt_model_dict = {
        (
            "segmentation_head.head" + name[len("scratch.output_conv") :]
            if name.startswith("scratch.output_conv")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }

    # Changing state dict keys for encoder layers
    dpt_model_dict = {
        (
            "encoder.model" + name[len("pretrained.model") :]
            if name.startswith("pretrained.model")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }

    # Removing key-value pairs associated with the auxiliary head
    dpt_model_dict = {
        name: parameter
        for name, parameter in dpt_model_dict.items()
        if not name.startswith("auxlayer")
    }

    smp_model.load_state_dict(dpt_model_dict, strict=True)
    smp_model.save_pretrained(HF_HUB_PATH, push_to_hub=True)
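A quick sanity check one might run afterwards (a sketch, assuming the push succeeded; reuses HF_HUB_PATH from above):

  # Hypothetical verification, not part of the PR script.
  restored = smp.from_pretrained(HF_HUB_PATH)
  restored.eval()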
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -14,6 +14,7 @@
 from .decoders.pan import PAN
 from .decoders.upernet import UPerNet
 from .decoders.segformer import Segformer
+from .decoders.dpt import DPT
 from .base.hub_mixin import from_pretrained
 
 from .__version__ import __version__
@@ -34,6 +35,7 @@
     PAN,
     UPerNet,
     Segformer,
+    DPT,
 ]
 MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}

@@ -84,6 +86,7 @@ def create_model(
"PAN",
"UPerNet",
"Segformer",
"DPT",
"from_pretrained",
"create_model",
"__version__",
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/dpt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .model import DPT

__all__ = ["DPT"]