Next, take the saved indices and use them to ablate the ResNet model.

In [1]:
from safetensors.torch import safe_open
import torch

In [49]:
# Read every tensor stored in the safetensors file into a dict keyed by tensor name.
ablate = {}
top_percent = 1
indices_path = f"./data/tasks/face_task2k_global_top_{top_percent}_percent.safetensors"
with safe_open(indices_path, framework="pt") as f:
    ablate.update((name, f.get_tensor(name)) for name in f.keys())
ablate

{'layer0': tensor([  4569,   4568,   4570, 156265, 156266, 156270,   5567, 156269, 183603,
         156267, 156264, 156268,   4558,   4559, 156209, 156271, 183662, 156321,
         156326, 156214, 156322,   5577, 156325, 183642, 156101, 156210, 156102,
           4739, 156098, 156097,   5576, 156153, 156213, 156158,   4567, 156157,
         156099, 156100, 156154, 156156, 156323, 156208, 156212, 156324, 156211,
         183641, 156155, 196082, 183659,   4740, 183589,   4557, 156215, 196081,
         156263,   5566, 156152, 156382, 156320, 156096, 156327,   4724,   4722,
         156272, 196026, 156103, 183661,   4741, 156381, 156207, 156378,   4624,
         156377, 156380, 183644, 156159, 183645,  40122, 196025, 156379, 155884,
           6133,  40123,   4571,  40125, 183586, 156383, 155940,  40124, 183602,
         155885,   4560, 156216,   5568, 155828, 156151, 183660, 155883, 155867,
         156046,   6132, 156095, 155941, 155830, 155886, 155774, 156328, 155829,
         156376, 1

In [50]:
from transformers import AutoImageProcessor, ResNetForImageClassification

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained ResNet-34 classifier (moved to DEVICE) and its matching image preprocessor.
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-34").to(DEVICE)
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-34")

In [51]:
from PIL import Image
@torch.no_grad()
def forward(image):
    """Run the global ``model`` on ``image`` and return the predicted class ids.

    ``image`` is handed straight to the global ``processor`` (the notebook
    passes either one PIL image or a list of them). The model is switched to
    eval mode on every call, and gradients are disabled by the decorator.

    Returns:
        A Python list with the argmax class id of each logits row.
    """
    model.eval()
    batch = processor(image, return_tensors="pt").to(DEVICE)
    predicted_ids = model(**batch).logits.argmax(-1)
    return predicted_ids.cpu().tolist()
def fetch_data(path):
    """Load ``./data/<path>`` and return it as an RGB PIL image.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the converted copy has been produced, instead of whenever the garbage
    collector gets to it.

    Args:
        path: Image location relative to the ``./data`` directory.

    Returns:
        The image converted to mode ``"RGB"``.
    """
    with Image.open(f"./data/{path}") as img:
        # convert() returns a new, fully loaded image, so it stays usable after the file closes.
        return img.convert("RGB")

In [61]:
class AblateTorchModel:
    """Context manager that rescales selected activations through forward hooks.

    One forward hook is registered per layer at construction time. Each hook
    flattens the layer output to ``(batch, -1)`` and multiplies the columns
    listed for that layer by ``scaler`` (``0`` zeroes them out). Leaving the
    ``with`` block, or calling ``remove_hooks()``, detaches every hook.

    Args:
        layers: Modules to hook, one per entry of ``ablate``.
        ablate: For every layer, the flat (per-sample) indices to rescale.
            An empty entry leaves that layer's output untouched.
        scaler: Factor applied to the selected activations.

    Raises:
        ValueError: If ``layers`` and ``ablate`` have different lengths.
    """

    def __init__(self, layers, ablate, scaler=0):
        layers = list(layers)
        ablate = list(ablate)
        # zip() would silently drop the unmatched tail and leave layers un-ablated.
        if len(layers) != len(ablate):
            raise ValueError(
                f"got {len(layers)} layers but {len(ablate)} index sets; "
                "they must pair up one-to-one"
            )
        self.layers = layers
        self.ablate = ablate
        self.hooks = []
        try:
            for layer, idxs in zip(layers, ablate):
                self.hooks.append(self.register_hook(layer, idxs, scaler))
        except Exception:
            # Do not leave half of the hooks attached to the model on failure.
            self.remove_hooks()
            raise

    def remove_hooks(self):
        """Detach every hook registered by this object."""
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def register_hook(self, layer, ablate, scaler):
        """Attach a forward hook to ``layer`` and return its removable handle.

        The hook returns a rescaled copy of the output rather than editing it
        in place, so tensors that alias the output (for instance the input of
        an identity module) are not modified, and non-contiguous outputs are
        handled as well.
        """
        def hook(module, inputs, outputs):
            if len(ablate) == 0:
                return outputs
            idx = torch.as_tensor(ablate, device=outputs.device)
            # A contiguous copy guarantees that view() below is valid.
            ablated = outputs.clone(memory_format=torch.contiguous_format)
            flat = ablated.view(ablated.shape[0], -1)
            flat[:, idx] *= scaler
            # Returning a value from a forward hook replaces the module output.
            return ablated
        return layer.register_forward_hook(hook)

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.remove_hooks()

# Every layer of every encoder stage, in order; the index file is expected to
# hold one "layer{i}" entry per such layer.
layers = [layer for stage in model.resnet.encoder.stages for layer in stage.layers]
# Derive the count from the model instead of hard-coding 16.
to_ablate_flat_idxs = [ablate[f"layer{i}"] for i in range(len(layers))]

path = "dog.jpeg"
# Load once and reuse the same image for both predictions.
image = fetch_data(path)

with AblateTorchModel(layers, to_ablate_flat_idxs):
    labels = forward(image)
    print("ablated", model.config.id2label[labels[0]])

# Hooks are removed when the with-block exits, so this is the unmodified model.
labels = forward(image)
print("regular", model.config.id2label[labels[0]])

ablated beagle
regular beagle


In [53]:
import pandas as pd
# Task table; the displayed columns are `data` (image path), `positive` and `label`.
df = pd.read_parquet("./data/tasks/face_task2k_with_labels.parquet")

In [54]:
# Show the table (pandas elides the middle rows in the rendered output).
df

Unnamed: 0,data,positive,label
0,celeba/398.jpg,True,585
1,celeba/3833.jpg,True,906
2,celeba/4836.jpg,True,903
3,celeba/4572.jpg,True,459
4,celeba/636.jpg,True,903
...,...,...,...
1995,val2017/000000294855.jpg,False,764
1996,val2017/000000142790.jpg,False,802
1997,val2017/000000297085.jpg,False,851
1998,val2017/000000140286.jpg,False,603


In [55]:
# Keep the first N rows flagged positive and the first N rows flagged negative.
N = 100
is_positive = df["positive"].eq(True)
is_negative = df["positive"].eq(False)
pos_examples = df[is_positive].head(N)
neg_examples = df[is_negative].head(N)

In [56]:
# Inspect the selected positive rows.
pos_examples

Unnamed: 0,data,positive,label
0,celeba/398.jpg,True,585
1,celeba/3833.jpg,True,906
2,celeba/4836.jpg,True,903
3,celeba/4572.jpg,True,459
4,celeba/636.jpg,True,903
...,...,...,...
95,celeba/2929.jpg,True,903
96,celeba/2814.jpg,True,655
97,celeba/4214.jpg,True,459
98,celeba/2831.jpg,True,937


In [57]:
# Inspect the selected negative rows.
neg_examples


Unnamed: 0,data,positive,label
1000,val2017/000000579307.jpg,False,977
1001,val2017/000000127517.jpg,False,693
1002,val2017/000000569565.jpg,False,555
1003,val2017/000000114770.jpg,False,586
1004,val2017/000000037751.jpg,False,671
...,...,...,...
1095,val2017/000000147518.jpg,False,861
1096,val2017/000000331352.jpg,False,861
1097,val2017/000000071756.jpg,False,362
1098,val2017/000000320425.jpg,False,293


In [58]:
def compute_ablated_and_regular(df):
    """Predict labels for the images in ``df`` with and without ablation.

    Every path in ``df["data"]`` is loaded with ``fetch_data`` and the whole
    list is classified in a single ``forward`` call, first with the plain
    model and then with the activations listed in the global ``ablate`` dict
    zeroed out.

    Args:
        df: DataFrame with a ``data`` column of image paths.

    Returns:
        ``(labels, ablated_labels)``: two lists of predicted class ids, in
        the same order as the rows of ``df``.
    """
    images = [fetch_data(p) for p in df["data"].tolist()]
    labels = forward(images)

    layers = [layer for stage in model.resnet.encoder.stages for layer in stage.layers]
    # One "layer{i}" entry per hooked layer; sized from the model rather than a literal 16.
    to_ablate_flat_idxs = [ablate[f"layer{i}"] for i in range(len(layers))]

    with AblateTorchModel(layers, to_ablate_flat_idxs, 0):
        ablated_labels = forward(images)

    return labels, ablated_labels

# Each result is a (regular_labels, ablated_labels) pair of label lists.
pos_res = compute_ablated_and_regular(pos_examples)
neg_res = compute_ablated_and_regular(neg_examples)

In [59]:
def compute_acc(y, yhat):
    """Return the fraction of positions where ``y`` and ``yhat`` are equal.

    Args:
        y: Reference labels (tensor, or anything ``torch.as_tensor`` accepts).
        yhat: Labels to compare against ``y``; must have the same shape.

    Returns:
        A 0-dim tensor: number of matching entries divided by ``len(y)``
        (``nan`` when both inputs are empty).

    Raises:
        ValueError: If the two inputs differ in shape; otherwise ``==`` would
            broadcast and silently yield a meaningless score.
    """
    y = torch.as_tensor(y)
    yhat = torch.as_tensor(yhat)
    if y.shape != yhat.shape:
        raise ValueError(
            f"shape mismatch: y has shape {tuple(y.shape)}, yhat has shape {tuple(yhat.shape)}"
        )
    correct = y == yhat
    return correct.sum() / len(y)
# Fraction of images whose predicted label is the same with and without ablation,
# for the positive set first and then the negative set.
for regular_labels, ablated_labels in (pos_res, neg_res):
    print(compute_acc(torch.tensor(regular_labels), torch.tensor(ablated_labels)))

tensor(0.)
tensor(0.5900)
