Description
We currently support loading pre-trained checkpoints for both the backbone and head layers, provided the input layer shapes match between the model and the checkpoint. However, we need to extend this to handle cases where the checkpoint has a single-channel input layer while the model has a 3-channel input layer (and vice versa). This requires strategies for channel reduction and expansion in the input layer.
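A minimal sketch of the two strategies, assuming PyTorch conv weights laid out as `(out_channels, in_channels, kH, kW)`. The helper name `adapt_input_conv_weight` is hypothetical and not part of sleap-nn. Reduction (3 to 1) sums the RGB filters, and expansion (1 to C) tiles the single filter and divides by C. Both choices keep the layer's response unchanged when a grayscale image is replicated across channels; other choices (averaging, or keeping a single channel) are also possible.

```python
import torch
import torch.nn.functional as F


def adapt_input_conv_weight(weight: torch.Tensor, target_channels: int) -> torch.Tensor:
    """Adapt a conv weight of shape (out, in, kH, kW) to `target_channels` inputs.

    Hypothetical helper (not part of sleap-nn):
      * reduction (C -> 1): sum the filters over the input-channel dimension;
      * expansion (1 -> C): repeat the single filter C times and divide by C.
    """
    in_channels = weight.shape[1]
    if in_channels == target_channels:
        return weight.clone()
    if target_channels == 1:
        # Channel reduction, e.g. RGB checkpoint -> grayscale model.
        return weight.sum(dim=1, keepdim=True)
    if in_channels == 1:
        # Channel expansion, e.g. grayscale checkpoint -> RGB model.
        return weight.repeat(1, target_channels, 1, 1) / target_channels
    raise ValueError(
        f"Unsupported adaptation: {in_channels} -> {target_channels} channels"
    )


if __name__ == "__main__":
    torch.manual_seed(0)
    w_rgb = torch.randn(8, 3, 3, 3)  # stand-in for a pretrained RGB stem conv
    x = torch.randn(2, 1, 16, 16)    # grayscale batch
    # Reduced weights on the gray input should match the RGB weights
    # applied to the same image replicated across 3 channels.
    y_rgb = F.conv2d(x.repeat(1, 3, 1, 1), w_rgb)
    y_gray = F.conv2d(x, adapt_input_conv_weight(w_rgb, 1))
    print(torch.allclose(y_rgb, y_gray, atol=1e-5))
```

Summing (rather than averaging) for reduction is what preserves the activations exactly under channel replication; averaging would scale them by 1/3.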
sleap-nn/sleap_nn/training/lightning_modules.py
Lines 212 to 227 in 69f3324
```python
# Initializing backbone (encoder + decoder) with trained ckpts
if self.pretrained_backbone_weights is not None:
    logger.info(
        f"Loading backbone weights from `{self.pretrained_backbone_weights}` ..."
    )
    ckpt = torch.load(
        self.pretrained_backbone_weights,
        map_location=self.trainer_accelerator,
        weights_only=False,
    )
    ckpt["state_dict"] = {
        k: ckpt["state_dict"][k]
        for k in ckpt["state_dict"].keys()
        if ".backbone" in k
    }
    self.load_state_dict(ckpt["state_dict"], strict=False)
```
This should also be handled for the `convnext` and `swint` architectures, and `input_expand_channels` should be removed:
sleap-nn/sleap_nn/training/lightning_modules.py
Lines 166 to 176 in 69f3324
```python
if self.backbone_type == "convnext" or self.backbone_type == "swint":
    if (
        self.backbone_config[f"{self.backbone_type}"]["pre_trained_weights"]
        is not None
    ):
        ckpt = MODEL_WEIGHTS[
            self.backbone_config[f"{self.backbone_type}"]["pre_trained_weights"]
        ].DEFAULT.get_state_dict(progress=True, check_hash=True)
        input_channels = ckpt["features.0.0.weight"].shape[-3]
        if self.in_channels != input_channels:  # TODO: not working!
            self.input_expand_channels = input_channels
```
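Note that `load_state_dict(..., strict=False)` only tolerates missing or unexpected keys; it still raises on size mismatches, so the checkpoint tensors must be reshaped before loading. A minimal sketch of a shape-aware adapter that both loading paths above could share, assuming the checkpoint and model differ only in the input-channel dimension of their input conv. The helper names `adapt_input_conv_weight` and `adapt_state_dict_input_channels` are hypothetical, not sleap-nn APIs; the same idea would apply to the torchvision `features.0.0.weight` tensor as an alternative to `input_expand_channels`.

```python
import torch
from torch import nn


def adapt_input_conv_weight(weight, target_channels):
    """Hypothetical helper: sum to reduce, tile and rescale to expand."""
    in_channels = weight.shape[1]
    if in_channels == target_channels:
        return weight
    if target_channels == 1:
        return weight.sum(dim=1, keepdim=True)  # reduction, e.g. RGB -> gray
    if in_channels == 1:
        return weight.repeat(1, target_channels, 1, 1) / target_channels  # expansion
    raise ValueError(f"Cannot adapt {in_channels} -> {target_channels} channels")


def adapt_state_dict_input_channels(ckpt_state, model_state):
    """Return a copy of `ckpt_state` whose conv weights match `model_state`.

    Only 4-D tensors that differ from the model's tensor solely in dim 1
    (input channels) are adapted; all other entries pass through unchanged.
    """
    adapted = dict(ckpt_state)
    for key, ckpt_w in ckpt_state.items():
        model_w = model_state.get(key)
        if (
            model_w is not None
            and ckpt_w.ndim == 4
            and model_w.ndim == 4
            and ckpt_w.shape[1] != model_w.shape[1]
            and ckpt_w.shape[0] == model_w.shape[0]
            and ckpt_w.shape[2:] == model_w.shape[2:]
        ):
            adapted[key] = adapt_input_conv_weight(ckpt_w, model_w.shape[1])
    return adapted


if __name__ == "__main__":
    # Hypothetical stand-ins: an RGB-trained checkpoint and a grayscale model.
    rgb_net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
    gray_net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
    state = adapt_state_dict_input_channels(
        rgb_net.state_dict(), gray_net.state_dict()
    )
    # Mirrors the existing call; loading no longer fails on the input conv.
    gray_net.load_state_dict(state, strict=False)
```

Restricting the adaptation to tensors that mismatch only in dim 1 keeps the helper from silently reshaping layers that differ for other reasons; those still fail loudly in `load_state_dict`.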