In [3]:
from transformers import AutoImageProcessor, ResNetForImageClassification
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for the checkpoint so the image processor and the
# classification model are always loaded from the same weights/config.
MODEL_NAME = "microsoft/resnet-34"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = ResNetForImageClassification.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # this notebook only runs inference

class ActivationTracker:
    """Capture the forward outputs of a fixed set of ``torch.nn.Module`` layers.

    A forward hook is registered on every layer at construction time. Each time a
    tracked layer runs, its output is detached from the autograd graph, moved to
    the CPU and stored in ``self.activations`` keyed by the layer module,
    replacing whatever was captured on the previous forward pass.

    Use it as a context manager so the hooks are always removed::

        with ActivationTracker(layers) as tracker:
            model(inputs)
            per_layer = tracker.activations_inorder()

    Attributes:
        activations: dict mapping layer module -> last captured output tensor.
        layers: the tracked layers, in the order given to the constructor.
        hooks: the handles returned by ``register_forward_hook``.

    NOTE(review): assumes each tracked layer's forward returns a single tensor
    (the hook calls ``.detach()`` on it directly).
    """

    def __init__(self, layers):
        """Register a forward hook on each module in ``layers``.

        Args:
            layers: iterable of ``torch.nn.Module``. It is materialised into a
                list so a one-shot iterator (e.g. a generator) still works for the
                ordering methods below.
        """
        self.activations = {}
        self.layers = list(layers)
        self.hooks = []
        try:
            for layer in self.layers:
                self.hooks.append(self.register_hook(layer))
        except Exception:
            # Don't leave already-registered hooks attached to the model if a
            # later registration fails (e.g. a non-Module slipped into `layers`).
            self.remove_hooks()
            raise

    def remove_hooks(self):
        """Detach every registered hook and drop all captured activations."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []
        # Reset to an empty dict (not a list) so the attribute keeps its type.
        self.activations = {}

    def register_hook(self, layer):
        """Attach a forward hook to ``layer`` that records its output.

        Returns:
            The handle from ``layer.register_forward_hook``; call ``.remove()``
            on it to detach the hook.
        """
        def hook(module, inputs, outputs):
            # Returning None leaves the module's output unchanged.
            self.activations[layer] = outputs.detach().cpu()

        return layer.register_forward_hook(hook)

    def shapes(self):
        """Return each tracked layer's last activation shape as a list of ints,
        in ``self.layers`` order.

        Raises:
            KeyError: if a tracked layer has not run since the tracker was created.
        """
        return [list(self.activations[layer].shape) for layer in self.layers]

    def flat_activations(self):
        """Flatten every tracked activation and concatenate them into one 1-D
        tensor, in ``self.layers`` order.

        Raises:
            KeyError: if a tracked layer has not run since the tracker was created.
        """
        return torch.cat([self.activations[layer].flatten() for layer in self.layers])

    def activations_inorder(self):
        """Return the last captured activation of each tracked layer, in
        ``self.layers`` order (a new list on every call)."""
        return [self.activations[layer] for layer in self.layers]

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Always remove hooks; returning None lets any exception propagate.
        self.remove_hooks()

preprocessor_config.json:   0%|          | 0.00/266 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.5k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

In [5]:
from PIL import Image

@torch.no_grad()
def forward(image):
    """Run the global classifier on ``image`` and return predicted class indices.

    Args:
        image: a PIL image or a list of them (the batch loop passes a list of
            RGB images); forwarded unchanged to the global ``processor``.

    Returns:
        ``logits.argmax(-1)`` as a plain Python list of ints, presumably one
        entry per input image.

    Runs under ``torch.no_grad()`` and puts the global ``model`` in eval mode
    before every call.
    """
    model.eval()
    encoded = processor(image, return_tensors="pt").to(DEVICE)
    outputs = model(**encoded)
    predicted = outputs.logits.argmax(dim=-1)
    return predicted.cpu().tolist()
def fetch_data(path, data_dir="./data"):
    """Load the image at ``{data_dir}/{path}`` and return it as an RGB PIL image.

    Args:
        path: file path relative to ``data_dir`` (e.g. a value of the ``data``
            column such as ``"celeba/398.jpg"``).
        data_dir: root directory prepended to ``path``; defaults to ``./data``.

    The source file is opened in a ``with`` block so its handle is closed
    deterministically; ``convert`` returns a new, fully loaded image, so the
    returned object does not depend on the closed file.
    """
    with Image.open(f"{data_dir}/{path}") as img:
        return img.convert("RGB")

In [6]:
import pandas as pd
df = pd.read_parquet("./data/tasks/face_task2k.parquet")
df

Unnamed: 0,data,positive
0,celeba/398.jpg,True
1,celeba/3833.jpg,True
2,celeba/4836.jpg,True
3,celeba/4572.jpg,True
4,celeba/636.jpg,True
...,...,...
1995,val2017/000000294855.jpg,False
1996,val2017/000000142790.jpg,False
1997,val2017/000000297085.jpg,False
1998,val2017/000000140286.jpg,False


In [7]:
from tqdm.notebook import tqdm

batch_size = 32
# Every module listed in each encoder stage's `layers`, in forward order; these
# are the modules whose outputs the tracker records.
layers = [
    block
    for stage in model.resnet.encoder.stages
    for block in stage.layers
]

labels_col = []  # predicted class index per df row, in row order
acts_col = []    # one list of per-layer activations per batch
with ActivationTracker(layers) as t:
    for i in tqdm(range(0, len(df), batch_size)):
        batch = df.iloc[i : i + batch_size]
        inputs = batch["data"].apply(fetch_data).tolist()
        labels = forward(inputs)
        labels_col.extend(labels)
        # forward() just fired every hook, so this snapshots this batch's outputs.
        acts_col.append(t.activations_inorder())

  0%|          | 0/63 [00:00<?, ?it/s]

In [8]:
def stack_layer_activations(acts: list[list[torch.Tensor]]) -> list[torch.Tensor]:
    """Turn per-batch activation lists into one stacked tensor per layer.

    Args:
        acts: one entry per batch; each entry is the list of per-layer
            activation tensors for that batch (e.g. what
            ``ActivationTracker.activations_inorder`` returns). Every batch must
            contain the same number of layers.

    Returns:
        One tensor per layer: that layer's per-batch tensors joined with
        ``torch.vstack`` (i.e. concatenated along dim 0 for inputs of rank >= 2).
        Returns ``[]`` when ``acts`` is empty.

    Raises:
        ValueError: if batches disagree on the number of layers.
    """
    if not acts:
        return []
    n_layers = len(acts[0])
    for batch_idx, batch_acts in enumerate(acts):
        if len(batch_acts) != n_layers:
            raise ValueError(
                f"batch {batch_idx} has {len(batch_acts)} layer activations, "
                f"expected {n_layers}"
            )
    # zip(*acts) regroups the batch-major input into layer-major tuples.
    return [torch.vstack(per_layer) for per_layer in zip(*acts)]

# Stack per-batch activations into one tensor per layer, then drop acts_col:
# torch.vstack copies, so keeping both would hold every activation twice.
# Guarded so re-running this cell after acts_col was deleted is a no-op
# (keeps the existing `stacked`) instead of raising NameError.
if "acts_col" in globals():
    stacked = stack_layer_activations(acts_col)
    del acts_col

In [9]:
from safetensors.torch import save_file
def export_safetensors(layers_tensors, name, metadata=None):
    """Save a sequence of tensors to a safetensors file as ``layer0``, ``layer1``, ...

    Args:
        layers_tensors: iterable of tensors; the i-th one is stored under the
            key ``f"layer{i}"``.
        name: output file path passed to ``save_file``.
        metadata: optional ``dict[str, str]`` written into the file header
            (defaults to no metadata, matching the previous behaviour).
    """
    export_dict = {
        # safetensors refuses to serialize non-contiguous tensors; .contiguous()
        # is a no-op for tensors that already are.
        f"layer{i}": tensor.contiguous()
        for i, tensor in enumerate(layers_tensors)
    }
    save_file(export_dict, name, metadata=metadata)

# One tensor per tracked layer, stored as layer0..layerN.
export_safetensors(
    stacked,
    "./data/tasks/face_task2k.safetensors",
)

In [13]:
# Attach the classifier's prediction (argmax over the model's logits) to every row.
# labels_col was filled batch by batch in df row order, so positions line up with df.
# NOTE(review): this mutates df in place after it was displayed earlier, and
# "label" is easy to confuse with the existing `positive` column (presumably the
# task target) — consider a more explicit name such as "pred_label".
df["label"] = labels_col

In [15]:
# Persist the frame, including the predicted "label" column, alongside the input
# parquet; index=False means df's index is not written to the file.
df.to_parquet("./data/tasks/face_task2k_with_labels.parquet", index=False)