
Handle input channel size while loading ckpts #137

@gitttt-1234

We currently support loading pre-trained checkpoints for both the backbone and head layers, provided the input-layer shapes of the model and the checkpoint are compatible. However, we need to extend this capability to handle cases where the checkpoint has a single-channel input layer while the model has a 3-channel input layer (and vice versa). We need to implement strategies for channel reduction/expansion for the input layer.
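One possible strategy, similar in spirit to how libraries like `timm` adapt pretrained weights: sum the checkpoint's first-layer filters over the input-channel dimension for 3 → 1 reduction, and replicate (and rescale) the single-channel filters for 1 → 3 expansion. A minimal sketch follows; the helper name `adapt_input_conv` is hypothetical and not part of the current codebase:

```python
import torch


def adapt_input_conv(weight: torch.Tensor, target_in_channels: int) -> torch.Tensor:
    """Adapt a conv weight of shape (out_ch, in_ch, kH, kW) to a new input channel count.

    Hypothetical helper (not in the current codebase):
    - 3 -> 1: sum filters over the input-channel dim, so a grayscale input
      produces activations on the same scale as the original RGB input.
    - 1 -> 3: replicate the single-channel filters and divide by the new
      channel count to keep the output magnitude roughly unchanged.
    """
    ckpt_in = weight.shape[1]
    if ckpt_in == target_in_channels:
        return weight
    if target_in_channels == 1:
        # Channel reduction: collapse the filters into one input channel.
        return weight.sum(dim=1, keepdim=True)
    if ckpt_in == 1:
        # Channel expansion: tile the single channel across the new channels.
        return weight.repeat(1, target_in_channels, 1, 1) / target_in_channels
    raise ValueError(f"Unsupported conversion: {ckpt_in} -> {target_in_channels}")
```

For reference, the current backbone-loading code: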

```python
# Initializing backbone (encoder + decoder) with trained ckpts
if self.pretrained_backbone_weights is not None:
    logger.info(
        f"Loading backbone weights from `{self.pretrained_backbone_weights}` ..."
    )
    ckpt = torch.load(
        self.pretrained_backbone_weights,
        map_location=self.trainer_accelerator,
        weights_only=False,
    )
    ckpt["state_dict"] = {
        k: ckpt["state_dict"][k]
        for k in ckpt["state_dict"].keys()
        if ".backbone" in k
    }
    self.load_state_dict(ckpt["state_dict"], strict=False)
```
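The helper could then be hooked into this loading path, adapting any conv weight whose input-channel dimension disagrees with the model before the final `load_state_dict` call. A sketch of the integration (hypothetical, reusing `adapt_input_conv` from above):

```python
# Sketch: adapt mismatched input-channel dims in the filtered state dict
# before loading (hypothetical integration, not existing code).
model_sd = self.state_dict()
for k, v in ckpt["state_dict"].items():
    if (
        k in model_sd
        and v.ndim == 4
        and model_sd[k].ndim == 4
        and v.shape[1] != model_sd[k].shape[1]
    ):
        ckpt["state_dict"][k] = adapt_input_conv(v, model_sd[k].shape[1])
self.load_state_dict(ckpt["state_dict"], strict=False)
```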

This should also be handled for the `convnext` and `swint` architectures (and `input_expand_channels` should be removed):

```python
if self.backbone_type == "convnext" or self.backbone_type == "swint":
    if (
        self.backbone_config[f"{self.backbone_type}"]["pre_trained_weights"]
        is not None
    ):
        ckpt = MODEL_WEIGHTS[
            self.backbone_config[f"{self.backbone_type}"]["pre_trained_weights"]
        ].DEFAULT.get_state_dict(progress=True, check_hash=True)
        input_channels = ckpt["features.0.0.weight"].shape[-3]
        if self.in_channels != input_channels:  # TODO: not working!
            self.input_expand_channels = input_channels
```
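For the torchvision `convnext`/`swint` weights, the same helper could replace the `input_expand_channels` workaround by adapting the stem/patch-embedding conv (`features.0.0.weight`) in the downloaded state dict directly. Sketch only; where the adapted dict is ultimately loaded depends on the encoder attribute, which is assumed here:

```python
# Sketch: adapt the torchvision checkpoint's first conv to the model's
# in_channels instead of setting input_expand_channels.
if self.in_channels != input_channels:
    ckpt["features.0.0.weight"] = adapt_input_conv(
        ckpt["features.0.0.weight"], self.in_channels
    )
# The adapted weights would then be loaded into the encoder, e.g.
# self.backbone.enc.load_state_dict(ckpt, strict=False)  # attribute name assumed
```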
