
ObjectDetectionTask support for MSI #1156

Closed
adamjstewart opened this issue Mar 4, 2023 · 6 comments · Fixed by #2602
Labels
good first issue (A good issue for a new contributor to work on) · trainers (PyTorch Lightning trainers)

Comments

@adamjstewart
Collaborator

Summary

The ObjectDetectionTask has an in_channels parameter, but it isn't actually used for anything. At the moment, it seems that the trainer only supports RGB imagery. We should fix this.
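
For illustration, a minimal sketch of the gap (argument names follow the hparams used in the trainer code later in this thread; the exact constructor signature may differ between TorchGeo versions):

    from torchgeo.trainers import ObjectDetectionTask

    # in_channels is accepted but never applied to the torchvision backbone,
    # so the detector underneath is still built for 3-band input and any
    # non-RGB batch fails inside the model.
    task = ObjectDetectionTask(
        model='faster-rcnn', backbone='resnet50', in_channels=4, num_classes=2
    )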

Rationale

We specialize in MSI; how can we not support MSI?

Implementation

We're currently using torchvision backbones, which makes things more challenging. How hard would it be to switch to timm backbones?

Alternatives

Alternatively, we'll have to override the first convolutional layer of the torchvision backbone.

Additional information

No response

@robmarkcole
Contributor

The adaptation for multi-channel input looks straightforward; I'm not sure about handling the pretrained weights:

https://github.com/allenai/vessel-detection-sentinels/blob/main/src/models/frcnn.py
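
For the pretrained-weights question, one common trick (sketched here as an assumption, not something the linked code does; widen_first_conv is a made-up helper) is to rebuild conv1 for the new band count and initialise the extra input channels from the pretrained RGB kernels, e.g. with their mean:

    import torch
    from torchvision.models import ResNet50_Weights
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

    def widen_first_conv(backbone, in_channels: int):
        """Replace conv1 with an in_channels-band layer, keeping pretrained RGB kernels.

        Assumes in_channels >= 3; extra bands are filled with the RGB mean.
        """
        old = backbone.body.conv1
        new = torch.nn.Conv2d(
            in_channels,
            old.out_channels,
            kernel_size=old.kernel_size,
            stride=old.stride,
            padding=old.padding,
            bias=False,
        )
        with torch.no_grad():
            new.weight[:, :3] = old.weight  # copy pretrained RGB kernels
            if in_channels > 3:
                new.weight[:, 3:] = old.weight.mean(dim=1, keepdim=True)
        backbone.body.conv1 = new
        return backbone

    backbone = resnet_fpn_backbone(
        backbone_name='resnet50', weights=ResNet50_Weights.IMAGENET1K_V1, trainable_layers=3
    )
    backbone = widen_first_conv(backbone, in_channels=4)

timm applies a similar adaptation automatically when in_chans != 3 is passed with pretrained weights.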

@robmarkcole
Contributor

I am taking a stab at this (offline); it appears straightforward:

    def configure_models(self) -> None:
        """Initialize the model.

        Raises:
            ValueError: If *model* or *backbone* are invalid.
        """
        backbone: str = self.hparams['backbone']
        model: str = self.hparams['model']
        weights: bool | None = self.hparams['weights']
        in_channels: int = self.hparams['in_channels']
        num_classes: int = self.hparams['num_classes']
        freeze_backbone: bool = self.hparams['freeze_backbone']

        if backbone in BACKBONE_LAT_DIM_MAP:
            kwargs = {
                'backbone_name': backbone,
                'trainable_layers': self.hparams['trainable_layers'],
            }
            if weights:
                kwargs['weights'] = BACKBONE_WEIGHT_MAP[backbone]
            else:
                kwargs['weights'] = None

            latent_dim = BACKBONE_LAT_DIM_MAP[backbone]
        else:
            raise ValueError(f"Backbone type '{backbone}' is not valid.")

        if model == 'faster-rcnn':
            model_backbone = resnet_fpn_backbone(**kwargs)

            if in_channels != 3:  # Adjust the first conv layer to match input channels
                first_conv_layer = model_backbone.body.conv1
                model_backbone.body.conv1 = torch.nn.Conv2d(
                    in_channels,
                    first_conv_layer.out_channels,
                    kernel_size=first_conv_layer.kernel_size,
                    stride=first_conv_layer.stride,
                    padding=first_conv_layer.padding,
                    bias=False
                )

            anchor_generator = AnchorGenerator(
                # one size tuple and one aspect-ratio tuple per FPN level
                sizes=((32,), (64,), (128,), (256,), (512,)),
                aspect_ratios=((0.5, 1.0, 2.0),) * 5,
            )

            roi_pooler = MultiScaleRoIAlign(
                featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2
            )

            if freeze_backbone:
                for param in model_backbone.parameters():
                    param.requires_grad = False

            self.model = torchvision.models.detection.FasterRCNN(
                model_backbone,
                num_classes,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
            )
        else: 
            raise ValueError(f"Model type '{model}' is not valid.")

However, in use I get an error:

    139 if image.dim() != 3:
    140     raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
--> 141 image = self.normalize(image)
    142 image, target_index = self.resize(image, target_index)
    143 images[i] = image

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/detection/transform.py:169, in GeneralizedRCNNTransform.normalize(self, image)
    167 mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
    168 std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
--> 169 return (image - mean[:, None, None]) / std[:, None, None]

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

Since kornia performs normalisation, why is torchvision raising this error?

@robmarkcole
Contributor

It appears torchvision also performs normalisation. I get around this with a dummy transform:

    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    class NoNormalizeTransform(GeneralizedRCNNTransform):
        def normalize(self, image):
            # Skip normalization, return the image as is
            return image

...and add it in configure_models:

            self.model.transform = NoNormalizeTransform(
                min_size=800,
                max_size=1333,
                image_mean=[0.0, 0.0, 0.0],  # Dummy values, won't be used
                image_std=[1.0, 1.0, 1.0],   # Dummy values, won't be used
            )
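
An alternative that avoids the subclass, assuming you keep torchvision's transform: FasterRCNN accepts image_mean and image_std keyword arguments and forwards them to GeneralizedRCNNTransform, so passing one entry per band keeps the channel count consistent (zeros and ones effectively make the extra normalisation a no-op, since the datamodule already normalises):

            # Sketch: give torchvision per-band statistics instead of disabling
            # normalisation; zeros/ones match in_channels and act as a no-op.
            self.model = torchvision.models.detection.FasterRCNN(
                model_backbone,
                num_classes,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
                image_mean=[0.0] * in_channels,
                image_std=[1.0] * in_channels,
            )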

@adamjstewart
Collaborator Author

Would need to see the full traceback and code to reproduce the bug you saw.

@adamjstewart
Collaborator Author

How hard would it be to switch to timm backbones?

I just tried this and it was much harder than expected. It isn't just a vanilla ResNet being used; torchvision modifies it to return feature maps with the right dimensions for the RPN head. We might have to give up on timm until someone makes an SMP equivalent for detection.
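
For reference, one way people wire timm into torchvision detection (purely a sketch; TimmFPNBackbone is a made-up name and this is not what TorchGeo does) is to wrap a features_only model with torchvision's FeaturePyramidNetwork so it exposes the interface FasterRCNN expects, which also shows how much extra plumbing is involved:

    import timm
    import torch
    from torch import nn
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

    class TimmFPNBackbone(nn.Module):
        """Hypothetical adapter: timm features_only backbone + torchvision FPN."""

        def __init__(self, name: str = 'resnet50', in_channels: int = 4, out_channels: int = 256) -> None:
            super().__init__()
            self.body = timm.create_model(
                name, features_only=True, in_chans=in_channels, out_indices=(1, 2, 3, 4)
            )
            self.fpn = FeaturePyramidNetwork(
                self.body.feature_info.channels(), out_channels, extra_blocks=LastLevelMaxPool()
            )
            self.out_channels = out_channels  # FasterRCNN reads this attribute

        def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
            feats = self.body(x)  # feature maps at strides 4, 8, 16, 32
            return self.fpn({str(i): f for i, f in enumerate(feats)})

FasterRCNN(TimmFPNBackbone(in_channels=4), num_classes=...) would then accept the same anchor generator and RoI pooler as the snippet above, but getting weights, freezing, and trainable_layers right is where it gets messy.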

@adamjstewart
Collaborator Author

The fix is straightforward; see what I did in #2513.

Testing is less straightforward. At the moment we don't have any non-RGB object detection datasets in TorchGeo. @robmarkcole suggested adding https://github.com/alina2204/contrastive_SSL_ship_detection if anyone has the time.
