Next, I want to load the saved per-layer indices and use them to ablate the corresponding activations of the ResNet model.

In [225]:
from safetensors.torch import safe_open
import torch

In [None]:
# Read every tensor stored in the safetensors file into a plain dict.
# Later cells look entries up by keys of the form "layer{i}".
top_percent = 0.5
ablate_path = f"./data/tasks/face_task2k_global_top_{top_percent}_percent.safetensors"
with safe_open(ablate_path, framework="pt") as f:
    ablate = {k: f.get_tensor(k) for k in f.keys()}
ablate

{'layer0': tensor([  4569,   4568,   4570, 156265, 156266, 156270,   5567, 156269, 183603,
         156267, 156264, 156268,   4558,   4559, 156209, 156271, 183662, 156321,
         156326, 156214, 156322,   5577, 156325, 183642, 156101, 156210, 156102,
           4739, 156098, 156097,   5576, 156153, 156213, 156158,   4567, 156157,
         156099, 156100, 156154, 156156, 156323, 156208, 156212]),
 'layer1': tensor([183641, 183662, 183661, 183642, 183585, 183640, 183663, 183603, 183602,
         183659, 183717,  32956, 199281, 183698, 199270, 199337, 142824, 183589,
         183588, 199338,  33012, 183644,   4568,  40122, 199269,   4569,  40069,
         199282, 183645, 183697,  40125,  40066, 182701, 182700, 182644, 182645,
         183716, 183586, 193671, 182641, 183718,  32957,  40124, 182647, 183702,
          40123, 152173,  33413, 193680,  40068,  40071, 183756, 182698, 199326,
         153025, 182648,  39052, 199325, 182642,  40067,   7870,  33248, 183699,
         199267,   792

In [251]:
from transformers import AutoImageProcessor, ResNetForImageClassification

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The image processor and the classifier are loaded from the same checkpoint;
# the classifier is moved to DEVICE right after loading.
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-34")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-34").to(DEVICE)

In [252]:
from PIL import Image
@torch.no_grad()
def forward(image):
    """Classify `image` with the global `model` and return predictions.

    `image` is whatever the global `processor` accepts; this notebook passes
    both a single PIL image and a list of PIL images.

    Returns a tuple `(labels, probs)`:
      * labels -- Python list with the argmax class id for each input,
      * probs  -- CPU tensor holding the softmax over the last logits axis.
    """
    # Make sure the model is in inference mode on every call.
    model.eval()
    batch = processor(image, return_tensors="pt").to(DEVICE)
    logits = model(**batch).logits
    predicted_ids = logits.argmax(-1).cpu().tolist()
    probabilities = logits.softmax(-1).cpu()
    return predicted_ids, probabilities
def fetch_data(path):
    """Load the image at ./data/<path> and return it converted to RGB.

    `path` is interpolated into "./data/{path}", so it should be relative to
    the ./data directory (e.g. "dog.jpeg").
    """
    # Image.open keeps the underlying file open; convert() returns a separate
    # image object, so the file can be closed as soon as the conversion is done
    # instead of lingering until garbage collection.
    with Image.open(f"./data/{path}") as img:
        return img.convert("RGB")

In [253]:
class AblateTorchModel:
    """Context manager that rescales selected activations of a set of modules.

    For every `(layer, indices)` pair taken from `zip(layers, ablate)` a forward
    hook is registered on `layer`. The hook flattens the layer output to
    `(batch, -1)`, multiplies the columns listed in `indices` by `scaler`
    (0 by default, i.e. the activations are zeroed) and hands the result back
    as the layer's output. Hooks are registered on construction and removed by
    `remove_hooks()` or on leaving the `with` block.

    Args:
        layers: iterable of modules exposing `register_forward_hook`.
        ablate: iterable of index collections, one per layer; each indexes the
            per-sample flattened output of the matching layer. An empty
            collection leaves that layer untouched.
        scaler: factor applied to the selected activations.

    Note: `zip` stops at the shorter of `layers` / `ablate`, so surplus entries
    on either side are ignored. Layer outputs are assumed to be single tensors.
    """

    def __init__(self, layers, ablate, scaler=0):
        self.layers = layers
        self.ablate = ablate
        self.hooks = []
        try:
            for layer, idxs in zip(layers, ablate):
                self.hooks.append(self.register_hook(layer, idxs, scaler))
        except Exception:
            # Registration failed part-way: detach the hooks that were already
            # installed so the model is not left half-instrumented.
            self.remove_hooks()
            raise

    def remove_hooks(self):
        """Detach every hook registered by this object."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []

    def register_hook(self, layer, ablate, scaler):
        """Attach the scaling hook to `layer` and return its removable handle."""
        def hook(module, inputs, outputs):
            if len(ablate) == 0:
                # Nothing to ablate for this layer: pass the output through.
                return outputs
            batch_size = outputs.shape[0]
            # reshape() also accepts non-contiguous outputs (view() raises on
            # them), and clone() guarantees we write into a private copy rather
            # than into the tensor the module produced. Mutating that tensor in
            # place can corrupt the caller's data for pass-through modules and
            # breaks backward for modules that save their output.
            flat = outputs.reshape(batch_size, -1).clone()
            flat[:, ablate] *= scaler
            # Returning a value from a forward hook replaces the module output.
            return flat.view_as(outputs)
        return layer.register_forward_hook(hook)

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Always detach the hooks, whether or not the block raised.
        self.remove_hooks()

# Flatten the encoder's stages into a single list of layers, in stage order.
layers = [
    block
    for stage in model.resnet.encoder.stages
    for block in stage.layers
]
# One index tensor per layer, looked up as "layer0" ... "layer15".
to_ablate_flat_idxs = [ablate[f"layer{i}"] for i in range(16)]

path = "dog.jpeg"

# Prediction while the ablation hooks are attached.
with AblateTorchModel(layers, to_ablate_flat_idxs) as a:
    image = fetch_data(path)
    labels, _ = forward(image)
    ablated_name = model.config.id2label[labels[0]]
    print("ablated", ablated_name)

# Prediction of the untouched model, for comparison.
image = fetch_data(path)
labels, _ = forward(image)
regular_name = model.config.id2label[labels[0]]
print("regular", regular_name)

ablated beagle
regular beagle


In [254]:
import pandas as pd
# Load the face-task table. Later cells read its `data` column (image paths
# passed to fetch_data, i.e. relative to ./data) and filter on its `positive`
# column. NOTE(review): `positive` presumably marks face images -- confirm.
df = pd.read_parquet("./data/tasks/face_task.parquet")

In [255]:
df

Unnamed: 0,data,positive
0,celeba/0.jpg,True
1,celeba/1.jpg,True
2,celeba/2.jpg,True
3,celeba/3.jpg,True
4,celeba/4.jpg,True
...,...,...
9995,val2017/000000512403.jpg,False
9996,val2017/000000168974.jpg,False
9997,val2017/000000552775.jpg,False
9998,val2017/000000394940.jpg,False


In [256]:
import numpy as np
# Seed NumPy's global RNG so the two draws below are repeatable when this cell
# is run on its own.
np.random.seed(0)
N = 200

# Build both row masks first, then draw N rows from each group
# (positives first, then negatives, so the RNG is consumed in the same order).
is_positive = df["positive"].eq(True)
is_negative = df["positive"].eq(False)
pos_examples = df.loc[is_positive].sample(n=N)
neg_examples = df.loc[is_negative].sample(n=N)

In [257]:
pos_examples

Unnamed: 0,data,positive
398,celeba/398.jpg,True
3833,celeba/3833.jpg,True
4836,celeba/4836.jpg,True
4572,celeba/4572.jpg,True
636,celeba/636.jpg,True
...,...,...
2776,celeba/2776.jpg,True
4506,celeba/4506.jpg,True
1803,celeba/1803.jpg,True
4363,celeba/4363.jpg,True


In [258]:
neg_examples


Unnamed: 0,data,positive
5597,val2017/000000579307.jpg,False
5332,val2017/000000127517.jpg,False
5521,val2017/000000569565.jpg,False
8935,val2017/000000114770.jpg,False
9800,val2017/000000037751.jpg,False
...,...,...
9872,val2017/000000407524.jpg,False
7470,val2017/000000181969.jpg,False
9060,val2017/000000376310.jpg,False
9024,val2017/000000328601.jpg,False


In [259]:
a = torch.ones((2,2)).softmax(-1)
b = torch.ones((2,2)).softmax(-1)

def kl_divergence(p, q):
    """Return KL(p || q) summed over the last axis, keeping that axis.

    Args:
        p, q: tensors of probabilities whose last axis is the distribution
            axis (they are combined elementwise, so shapes must broadcast).

    Returns:
        Tensor with the last axis reduced to size 1.

    Entries where `p == 0` contribute 0 (the usual 0 * log 0 = 0 convention);
    without this, `0 * log(0)` evaluates to NaN and poisons the whole sum.
    Entries where `p > 0` and `q == 0` still yield `inf`.
    """
    pointwise = p * (p / q).log()
    # Zero-probability events under p carry no divergence.
    pointwise = pointwise.masked_fill(p == 0, 0.0)
    return pointwise.sum(-1, keepdim=True)

# Sanity check: identical distributions have zero divergence.
kl_divergence(a, b).mean()

tensor(0.)

In [260]:

def compute_ablated_and_regular(df):
    """Run the model with and without ablation on every image listed in `df`.

    Args:
        df: frame whose `data` column holds image paths understood by
            `fetch_data`.

    Returns:
        `((labels, ablated_labels), (probs, ablated_probs))`, where each pair
        holds the outputs of `forward` for the untouched model first and for
        the ablated model second.
    """
    images = [fetch_data(rel_path) for rel_path in df["data"].tolist()]

    # Baseline: the model as loaded, all images in a single batch.
    labels, probs = forward(images)

    # Same batch again with the scaling hooks (factor 0) attached to every
    # encoder layer; the hooks are removed when the with-block exits.
    encoder_layers = [
        block
        for stage in model.resnet.encoder.stages
        for block in stage.layers
    ]
    flat_idxs = [ablate[f"layer{i}"] for i in range(16)]
    with AblateTorchModel(encoder_layers, flat_idxs, 0):
        ablated_labels, ablated_probs = forward(images)

    return (labels, ablated_labels), (probs, ablated_probs)

# Each result has the form
#   ((regular_labels, ablated_labels), (regular_probs, ablated_probs))
# for the sampled positive / negative rows respectively.
pos_res = compute_ablated_and_regular(pos_examples)
neg_res = compute_ablated_and_regular(neg_examples)

In [261]:
def compute_acc(y, yhat):
    """Return the fraction of positions at which `y` and `yhat` are equal.

    Both arguments must support elementwise `==` returning an object with
    `.sum()` (e.g. tensors or arrays); the count is divided by `len(y)`.
    """
    n_total = len(y)
    n_equal = (y == yhat).sum()
    return n_equal / n_total

# Agreement between the regular and the ablated predictions,
# positives first, then negatives.
for res in (pos_res, neg_res):
    regular_labels, ablated_labels = res[0]
    print(compute_acc(torch.tensor(regular_labels), torch.tensor(ablated_labels)))

print("avg kl divergence")
# Mean KL(regular || ablated) over the sampled images, same order as above.
for res in (pos_res, neg_res):
    regular_probs, ablated_probs = res[1]
    print(kl_divergence(regular_probs, ablated_probs).mean())

tensor(0.0200)
tensor(0.7700)
avg kl divergence
tensor(3.2989)
tensor(0.1946)
