In [1]:
from transformers import AutoImageProcessor, ResNetForImageClassification
import torch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Single source of truth for the Hub checkpoint: the image processor and the
# classifier weights must come from the same checkpoint.
CHECKPOINT = "microsoft/resnet-34"

# Pin the slow image processor explicitly. transformers warns that the default
# will switch to the fast processor (with "minor differences in outputs"), so
# leaving `use_fast` unset would make results depend on the installed version.
processor = AutoImageProcessor.from_pretrained(CHECKPOINT, use_fast=False)
model = ResNetForImageClassification.from_pretrained(CHECKPOINT).to(DEVICE)

2025-03-02 23:55:23.188539: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
from PIL import Image
@torch.no_grad()
def forward(image):
    """Classify ``image`` and return the arg-max class index per batch row.

    Relies on the module-level ``processor``, ``model`` and ``DEVICE``.
    Gradient tracking is disabled by the decorator, and the model is switched
    to eval mode on every call.

    Args:
        image: anything ``processor`` accepts (e.g. the PIL image returned by
            ``fetch_data``).

    Returns:
        A Python list of ints: the index of the largest logit along the last
        axis, presumably one entry per input image. The entries are keys of
        ``model.config.id2label``.
    """
    model.eval()
    pixel_batch = processor(image, return_tensors="pt")
    pixel_batch = pixel_batch.to(DEVICE)
    scores = model(**pixel_batch).logits
    best = scores.argmax(dim=-1)
    return best.cpu().tolist()
def fetch_data(path, data_dir="./data"):
    """Load ``f"{data_dir}/{path}"`` as an RGB PIL image.

    Args:
        path: file name relative to ``data_dir``.
        data_dir: directory holding the images. Defaults to ``"./data"``,
            matching the previously hard-coded location.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode, fully loaded into memory.
    """
    # Image.open is lazy and keeps the file handle open until the image is
    # garbage-collected; convert() returns a loaded copy, so the source file
    # can be closed immediately by the context manager.
    with Image.open(f"{data_dir}/{path}") as img:
        return img.convert("RGB")

In [46]:
class AblateTorchModel:
    """Context manager that rescales selected activations of PyTorch modules.

    On construction, a forward hook is registered on every module in
    ``layers``. After a module runs, its hook flattens each sample's output to
    1-D (keeping the leading batch dimension). It then multiplies the entries
    at the flat indices ``ablate[k]`` by ``scaler`` and returns the result as
    the module's new output. ``scaler=0`` zeroes the entries and ``scaler=-1``
    negates them. Hooks stay attached until ``remove_hooks()`` is called or
    the ``with`` block exits.

    Args:
        layers: modules to hook. Their outputs are assumed to be single
            tensors with a leading batch dimension (tuple outputs are not
            handled).
        ablate: one collection of flat per-sample indices for each module in
            ``layers``, in the same order.
        scaler: multiplier applied to the selected entries. Defaults to 0.

    Raises:
        ValueError: if ``layers`` and ``ablate`` differ in length.
    """

    def __init__(self, layers, ablate, scaler=0):
        self.layers = layers
        self.ablate = ablate
        self.hooks = []
        layer_list = list(layers)
        index_lists = list(ablate)
        # zip() would silently drop the tail of the longer sequence, leaving
        # some layers unhooked (or some index lists unused) without warning.
        if len(layer_list) != len(index_lists):
            raise ValueError(
                f"got {len(layer_list)} layers but {len(index_lists)} index lists in `ablate`"
            )
        try:
            for layer, idxs in zip(layer_list, index_lists):
                self.hooks.append(self.register_hook(layer, idxs, scaler))
        except BaseException:
            # Don't leave earlier hooks attached if a later registration fails.
            self.remove_hooks()
            raise

    def remove_hooks(self):
        """Detach every hook registered by this instance."""
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def register_hook(self, layer, ablate, scaler):
        """Attach the rescaling forward hook to ``layer`` and return its handle."""
        def hook(module, inputs, outputs):
            batch = outputs.shape[0]
            # reshape() (unlike view()) also works on non-contiguous outputs,
            # and clone() guarantees we edit a private copy: scaling the
            # module's own output in place could corrupt tensors saved for
            # backward or still referenced elsewhere.
            flat = outputs.reshape(batch, -1).clone()
            flat[:, ablate] *= scaler
            # A value returned from a forward hook replaces the module output.
            return flat.reshape(outputs.shape)
        return layer.register_forward_hook(hook)

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Always detach the hooks; returning None lets exceptions propagate.
        self.remove_hooks()

# Hook every layer of every encoder stage of the ResNet.
layers = [layer for stage in model.resnet.encoder.stages for layer in stage.layers]

# Rescale the first N_ABLATE flattened (per-sample) activations of each hooked
# layer's output. With ABLATE_SCALE = -1 they are negated, not zeroed.
N_ABLATE = 2_000
ABLATE_SCALE = -1
to_ablate_flat_idxs = [list(range(N_ABLATE)) for _ in layers]

image = fetch_data("dog.jpeg")
with AblateTorchModel(layers, to_ablate_flat_idxs, ABLATE_SCALE):
    labels = forward(image)
print("predicted", model.config.id2label[labels[0]])
    

predicted bath towel
