In [7]:
from transformers import DetrForObjectDetection, DetrConfig
from torchinfo import summary
import torch
import torch.nn as nn

def _find_resnet_backbone(detr_model):
    """Walk known attribute layouts and return the backbone network module.

    The attribute chain from the DETR wrapper down to the network that is
    expected to expose ``layer1`` .. ``layer4`` differs between library
    versions, so several layouts are tried in order and the first one that
    resolves is returned.

    Args:
        detr_model: The (unwrapped) DETR detection model.

    Returns:
        The module found at the first resolvable path.

    Raises:
        AttributeError: If none of the known layouts matches ``detr_model``.
    """
    candidate_paths = (
        "model.backbone.backbone.model",
        "model.backbone.model",
        # Layout where the conv encoder is wrapped together with the position
        # embedding and exposed as ``conv_encoder`` (older transformers).
        "model.backbone.conv_encoder.model",
    )
    for path in candidate_paths:
        node = detr_model
        try:
            for attr in path.split("."):
                node = getattr(node, attr)
        except AttributeError:
            continue
        return node
    raise AttributeError(
        f"Could not locate the ResNet backbone on {type(detr_model).__name__}; "
        f"tried: {', '.join(candidate_paths)}"
    )


def prune_resnet_backbone(model, pruning_dict):
    """Disable trailing blocks of ResNet stages in a DETR model, in place.

    For every ``{stage_number: blocks_to_keep}`` entry, the blocks of
    ``layer<stage_number>`` from index ``blocks_to_keep`` onward are replaced
    with ``nn.Identity``. The stage keeps its original length, so indices of
    the surviving blocks do not shift.

    Args:
        model: A DETR detection model, or a wrapper exposing
            ``get_base_model()`` (e.g. a PEFT model), which is unwrapped first.
        pruning_dict: Mapping of stage number to the number of leading blocks
            to keep. Values below 1 are raised to 1 so block 0 always stays.

    Returns:
        None. ``model`` is modified in place and progress is printed.

    Raises:
        AttributeError: If the backbone cannot be located on ``model``.
    """
    curr_model = model.get_base_model() if hasattr(model, "get_base_model") else model
    backbone = _find_resnet_backbone(curr_model)
    print(f"Successfully reached backbone: {type(backbone).__name__}")

    for stage_num, keep_count in pruning_dict.items():
        stage_name = f"layer{stage_num}"
        if not hasattr(backbone, stage_name):
            print(f"Warning: Stage {stage_num} ({stage_name}) not found. Skipping.")
            continue

        stage = getattr(backbone, stage_name)

        # Guard: always keep at least block 0 (the downsampling block).
        keep_count = max(keep_count, 1)

        # Count only blocks that still do work, so a repeated call on an
        # already-pruned model reports (and changes) only what is really left.
        active_before = sum(not isinstance(block, nn.Identity) for block in stage)

        # NOTE(review): Identity is only a valid stand-in if the removed blocks
        # preserve tensor shape (expected for non-first ResNet blocks) - confirm
        # for other backbones.
        for idx in range(keep_count, len(stage)):
            if not isinstance(stage[idx], nn.Identity):
                stage[idx] = nn.Identity()

        active_after = sum(not isinstance(block, nn.Identity) for block in stage)
        if active_after == active_before:
            # Nothing was removed (keep_count covers every active block).
            continue

        print(f"Pruned Stage {stage_num}: Reduced from {active_before} to {active_after} active blocks.")

# Ablation setting: keep a single block in each of stages 1 through 4.
ABLATION_DICT = {stage_idx: 1 for stage_idx in range(1, 5)}

# NOTE: `config` is only consumed by the commented-out loader further down;
# the active from_pretrained call does not receive it.
config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
config.num_labels = 2

# Load the pretrained checkpoint (or reuse an existing `model` variable).
# Alternative that applies `config` (adjust num_labels as needed):
# model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config, ignore_mismatched_sizes=True)
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# Prune the backbone in place according to the ablation setting.
prune_resnet_backbone(model, pruning_dict=ABLATION_DICT)

# Dump the full layer-by-layer architecture after pruning.
print(model)

# Random stand-in image batch laid out as [batch, channels, height, width],
# used only to trace shapes through the network.
dummy_input = torch.randn((1, 3, 800, 800))

summary(model, input_data=dummy_input, depth=3)

Loading weights: 100%|██████████| 530/530 [00:00<00:00, 684.76it/s, Materializing param=model.query_position_embeddings.weight]                 
[1mDetrForObjectDetection LOAD REPORT[0m from: facebook/detr-resnet-50
Key                                                            | Status     |  | 
---------------------------------------------------------------+------------+--+-
model.backbone.model.layer3.0.downsample.1.num_batches_tracked | UNEXPECTED |  | 
model.backbone.model.layer1.0.downsample.1.num_batches_tracked | UNEXPECTED |  | 
model.backbone.model.layer4.0.downsample.1.num_batches_tracked | UNEXPECTED |  | 
model.backbone.model.layer2.0.downsample.1.num_batches_tracked | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Successfully reached backbone: FeatureListNet
Pruned Stage 1: Reduced from 3 to 1 active blocks.
Pruned Stage 2: Reduced from 4 to 1 active blocks.
Pruned Stage 3: Reduced from 6 to 1 active blocks.
Pruned Stage 4: Reduced from 3 to 1 active blocks.
DetrForObjectDetection(
  (model): DetrModel(
    (backbone): DetrConvEncoder(
      (model): FeatureListNet(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): DetrFrozenBatchNorm2d()
        (act1): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): DetrFrozenBatchNorm2d()
            (act1): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): DetrFrozenBatchNorm2d()
            (d

Layer (type:depth-idx)                                            Output Shape              Param #
DetrForObjectDetection                                            [1, 625, 256]             --
├─DetrModel: 1-1                                                  [1, 625, 256]             25,600
│    └─DetrConvEncoder: 2-1                                       [1, 256, 200, 200]        --
│    │    └─FeatureListNet: 3-1                                   [1, 256, 200, 200]        7,996,608
│    └─Conv2d: 2-2                                                [1, 256, 25, 25]          524,544
│    └─DetrSinePositionEmbedding: 2-3                             [1, 625, 256]             --
│    └─DetrEncoder: 2-4                                           [1, 625, 256]             --
│    │    └─ModuleList: 3-2                                       --                        7,890,432
│    └─DetrDecoder: 2-5                                           [1, 100, 256]             --
│    │    └─ModuleList