In [None]:
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests
import torch
import numpy as np
import torch
from pathlib import Path
import json
from typing import Dict
import itertools

In [None]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTModel.from_pretrained('google/vit-base-patch16-224')

# Turn the pooler into an exact identity so that `pooler_output` is just the
# final (layer-normed) CLS-token hidden state, i.e. drop the extra head layers.
# (This classification checkpoint presumably ships no pooler weights, so they
# would otherwise be randomly initialised.)
# ViTPooler computes `activation(dense(cls))` with a tanh activation by default:
# making `dense` an identity alone would still squash every feature into
# (-1, 1), so the activation is replaced too.
dense = model.pooler.dense
assert dense.in_features == dense.out_features, "identity pooler needs a square dense layer"
with torch.no_grad():
    dense.weight.copy_(
        torch.eye(dense.out_features, dtype=dense.weight.dtype, device=dense.weight.device)
    )
    dense.bias.zero_()
model.pooler.activation = torch.nn.Identity()
# from_pretrained already sets eval mode; kept explicit because dropout must be
# off for deterministic feature extraction.
model.eval()

# Cheap sanity check: the pooler must now return the first token unchanged.
with torch.no_grad():
    _probe = torch.randn(2, 3, dense.in_features)
    assert torch.allclose(model.pooler(_probe), _probe[:, 0])

In [None]:

DS_DIR = Path(".").absolute() / "dataset"


def load_labels(path: Path) -> Dict[int, str]:
    """Read a ``{"<image id>": "<label>"}`` JSON file, converting the ids to int."""
    return {int(k): v for k, v in json.loads(path.read_text()).items()}


# Files in `cropped/` are named "<image id>.jpg"; index them by integer id.
images = {int(f.stem): f for f in (DS_DIR / "cropped").glob("*.jpg")}
label_train = load_labels(DS_DIR / "labels_train.json")
label_test = load_labels(DS_DIR / "labels_test.json")

assert len(images) == len(label_train) + len(label_test) == 10_000
assert label_train.keys().isdisjoint(label_test.keys()), "an image id is in both splits"
assert images.keys() == label_train.keys() | label_test.keys(), "labelled ids and image files differ"

uniq_labels = set(label_test.values()).union(label_train.values())
# Sort before enumerating: str hashing is randomized per process, so iterating a
# bare set yields a different order on every kernel start and the integer
# targets saved to embedding.pt could not be mapped back to label names.
label_map = {label: i for i, label in enumerate(sorted(uniq_labels))}

In [None]:
def get_features(
    labels: Dict[int, str],
    label_map: Dict[str, int],
    images: Dict[int, Path],
    B=32
):
    """Embed every labelled image with the global ViT ``model``/``processor``.

    Samples are ordered by ascending image id, so ``X[i]`` and ``y[i]`` always
    describe the same image.

    Args:
        labels: image id -> label name for the split to embed.
        label_map: label name -> integer class index.
        images: image id -> path of the image file; must contain every id in ``labels``.
        B: batch size for the forward passes (must be >= 1).

    Returns:
        ``(X, y)``: ``X`` concatenates the ``pooler_output`` rows of every batch
        (one row per image); ``y`` is an int64 tensor of class indices.

    Raises:
        ValueError: if ``B`` < 1.
        KeyError: if a labelled id has no image or a label is missing from ``label_map``.
    """
    if B < 1:
        raise ValueError(f"batch size B must be >= 1, got {B}")

    ids = sorted(labels)
    paths = [images[k] for k in ids]
    # torch.long is int64 on every platform (np "int" is int32 on Windows/numpy<2).
    y = torch.tensor([label_map[labels[k]] for k in ids], dtype=torch.long)

    n_batches = (len(paths) + B - 1) // B
    X = []
    for k, start in enumerate(range(0, len(paths), B)):
        print(f"batch {k + 1}/{n_batches}", flush=True)
        imgs = []
        for path in paths[start : start + B]:
            # `convert` forces the pixel data to be read so the file can be closed
            # immediately; RGB guarantees 3 channels for the processor.
            with Image.open(path) as img:
                imgs.append(img.convert("RGB"))
        inputs = processor(images=imgs, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        X.append(outputs["pooler_output"])

    if not X:
        # torch.concat([]) raises; return an empty feature matrix instead.
        # Assumes pooler width == hidden_size (true for the square identity pooler).
        return torch.empty(0, model.config.hidden_size), y
    return torch.concat(X), y

# Embed both splits (test first, then train); each call returns (features, class indices).
X_test, y_test = get_features(labels=label_test, label_map=label_map, images=images)
X_train, y_train = get_features(labels=label_train, label_map=label_map, images=images)


In [None]:
# Every split needs one target per feature row, and together they cover the dataset.
for feats, targets in ((X_test, y_test), (X_train, y_train)):
    assert len(feats) == len(targets)
assert len(X_train) + len(X_test) == 10_000

payload = {
    "X_test": X_test,
    "y_test": y_test,
    "X_train": X_train,
    "y_train": y_train,
}
with (DS_DIR / "embedding.pt").open("wb") as fh:
    torch.save(payload, fh)

In [None]:
# Reload the saved embeddings and display the shape of each stored tensor.
embedding_path = DS_DIR / "embedding.pt"
with embedding_path.open("rb") as fh:
    data = torch.load(fh)
{name: tensor.shape for name, tensor in data.items()}