ObjectDetectionTask support for MSI #1156
Comments
The adaptation for multi-channel input looks straightforward, but I'm not sure how to handle the pretrained weights. For reference: https://github.com/allenai/vessel-detection-sentinels/blob/main/src/models/frcnn.py |
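One common heuristic for the pretrained-weights question is to reuse the RGB filters of the first convolution for the extra bands and rescale them so the activation magnitude stays roughly comparable. The following is a minimal sketch under that assumption, not the approach used by the linked repository or by TorchGeo; the helper name `adapt_first_conv` is hypothetical.

```python
import torch
from torch import nn


def adapt_first_conv(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Return a new conv layer accepting `in_channels` bands (hypothetical helper).

    The pretrained filters are cycled across the new input bands and scaled
    by old_in / in_channels. This is a heuristic, not a guaranteed-good init.
    """
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        old_weight = conv.weight  # shape: (out, old_in, kH, kW)
        old_in = old_weight.shape[1]
        # Map every new band to one of the pretrained bands: 0, 1, 2, 0, 1, ...
        band_index = torch.arange(in_channels) % old_in
        new_conv.weight.copy_(old_weight[:, band_index] * (old_in / in_channels))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Example: turn a ResNet-style 3-band stem into a 5-band stem.
rgb_conv = nn.Conv2d(3, 8, kernel_size=7, stride=2, padding=3, bias=False)
msi_conv = adapt_first_conv(rgb_conv, in_channels=5)
features = msi_conv(torch.rand(1, 5, 32, 32))
```

In the Faster R-CNN case discussed in this thread, the result would be assigned back with something like `model_backbone.body.conv1 = adapt_first_conv(model_backbone.body.conv1, in_channels)`. Whether the copied filters are meaningful for non-RGB bands depends on the sensor, so fine-tuning the new layer is probably still needed.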
I am taking a stab at this (offline) - it appears straightforward:

def configure_models(self) -> None:
    """Initialize the model.

    Raises:
        ValueError: If *model* or *backbone* are invalid.
    """
    backbone: str = self.hparams['backbone']
    model: str = self.hparams['model']
    weights: bool | None = self.hparams['weights']
    in_channels: int = self.hparams['in_channels']
    num_classes: int = self.hparams['num_classes']
    freeze_backbone: bool = self.hparams['freeze_backbone']

    if backbone in BACKBONE_LAT_DIM_MAP:
        kwargs = {
            'backbone_name': backbone,
            'trainable_layers': self.hparams['trainable_layers'],
        }
        if weights:
            kwargs['weights'] = BACKBONE_WEIGHT_MAP[backbone]
        else:
            kwargs['weights'] = None
        latent_dim = BACKBONE_LAT_DIM_MAP[backbone]
    else:
        raise ValueError(f"Backbone type '{backbone}' is not valid.")

    if model == 'faster-rcnn':
        model_backbone = resnet_fpn_backbone(**kwargs)

        if in_channels != 3:  # Adjust the first conv layer to match input channels
            first_conv_layer = model_backbone.body.conv1
            model_backbone.body.conv1 = torch.nn.Conv2d(
                in_channels,
                first_conv_layer.out_channels,
                kernel_size=first_conv_layer.kernel_size,
                stride=first_conv_layer.stride,
                padding=first_conv_layer.padding,
                bias=False,
            )

        anchor_generator = AnchorGenerator(
            sizes=((32), (64), (128), (256), (512)), aspect_ratios=((0.5, 1.0, 2.0))
        )
        roi_pooler = MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2
        )

        if freeze_backbone:
            for param in model_backbone.parameters():
                param.requires_grad = False

        self.model = torchvision.models.detection.FasterRCNN(
            model_backbone,
            num_classes,
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
        )
    else:
        raise ValueError(f"Model type '{model}' is not valid.")

However, in use I get an error:
Since kornia performs normalisation, why is torchvision raising this error? |
It appears torchvision is also performing normalisation - I get around this with a dummy transform:

from torchvision.models.detection.transform import GeneralizedRCNNTransform

class NoNormalizeTransform(GeneralizedRCNNTransform):
    def normalize(self, image):
        # Skip normalization, return the image as is
        return image

... add to config

self.model.transform = NoNormalizeTransform(
    min_size=800,
    max_size=1333,
    image_mean=[0.0, 0.0, 0.0],  # Dummy values, won't be used
    image_std=[1.0, 1.0, 1.0],  # Dummy values, won't be used
) |
Would need to see the full traceback and code to reproduce the bug you saw. |
I just tried this and it was much harder than expected. It isn't just a vanilla ResNet being used, they modify it to return info with the right dimensions to the RPN head. Might have to give up on timm until someone makes an SMP equivalent for detection. |
Fix is straightforward, see what I did in #2513. Testing is less straightforward. At the moment we don't have any non-RGB object detection datasets in TorchGeo. @robmarkcole suggested adding https://github.com/alina2204/contrastive_SSL_ship_detection if anyone has the time. |
Summary
The ObjectDetectionTask has an in_channels parameter but it isn't actually used for anything. At the moment, it seems that the trainer only supports RGB imagery. We should fix this.
Rationale
We specialize in MSI, so how can we not support MSI?
Implementation
We're currently using torchvision backbones, which makes things more challenging. How hard would it be to switch to timm backbones?
Alternatives
Alternatively, we'll have to override the first convolutional layer of the torchvision backbone.
Additional information
No response