In [3]:
# batch processing 
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import clip
import pandas as pd
import numpy as np
from tqdm import tqdm
from pathlib import Path

# Directory holding the training images (relative path); ImageDataset below
# resolves each image id to "<IMG_PATH>/<id>.jpg".
IMG_PATH = '../data/train_images/'


# Training metadata table; its "image" column supplies the image ids.
df = pd.read_csv('../data/raw/train.csv')

In [64]:
# Sanity check: does PyTorch see a CUDA device in this environment?
torch.cuda.is_available()

False

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the CLIP ViT-B/32 model onto that device; `preprocess` is the matching
# image transform and is passed to the dataset as its `transform`.
model, preprocess = clip.load("ViT-B/32", device=device)

class ImageDataset(Dataset):
    """Dataset yielding ``(transformed_image, image_id)`` pairs.

    Image ids are read from the ``"image"`` column of ``df`` and each id is
    resolved to the file ``<img_dir>/<image_id>.jpg``.

    Args:
        df: Table with an ``"image"`` column of image ids (file stems).
        transform: Callable applied to the loaded RGB PIL image.
        img_dir: Directory containing the ``.jpg`` files. ``None`` (default)
            uses the module-level ``IMG_PATH``, looked up at access time.
    """

    # Size of the all-black placeholder substituted for unreadable images.
    PLACEHOLDER_SIZE = (224, 224)

    def __init__(self, df, transform, img_dir=None):
        self.image_paths = df["image"].values
        self.transform = transform
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of image ids in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position ``idx``.

        Loading is best-effort: if the file cannot be opened or decoded, a
        warning naming the image is emitted and a black placeholder image is
        transformed instead, so one bad file does not abort a long run.

        Returns:
            Tuple ``(transform(image), image_id)``.
        """
        import warnings  # local import keeps this cell self-contained

        img_path = self.image_paths[idx]
        try:
            img_dir = IMG_PATH if self.img_dir is None else self.img_dir
            # Context manager closes the underlying file handle promptly;
            # convert() returns an independent in-memory copy.
            with Image.open(Path(img_dir, f"{img_path}.jpg")) as img:
                image = img.convert("RGB")
        except Exception as exc:
            # Keep going, but make the substitution visible: the embedding for
            # this id will describe a black image, not the real picture.
            warnings.warn(
                f"Error loading image {img_path}: {exc}; using black placeholder",
                RuntimeWarning,
            )
            image = Image.new("RGB", self.PLACEHOLDER_SIZE, (0, 0, 0))

        return self.transform(image), img_path


def collate_fn(batch):
    """Collate ``(image_tensor, image_id)`` samples into one batch.

    Args:
        batch: Sequence of two-element samples as produced by the dataset.

    Returns:
        Tuple ``(stacked, image_ids)`` where ``stacked`` is the tensors stacked
        along a new leading dimension and ``image_ids`` is a tuple of the ids
        in the same order.
    """
    # Transpose the list of pairs into [all tensors, all ids].
    columns = list(zip(*batch))
    image_tensors, image_ids = columns
    stacked = torch.stack(image_tensors)
    return stacked, image_ids

In [8]:
# Wrap the dataframe in a Dataset and iterate it in fixed-size batches.
# shuffle=False keeps the dataframe order.
dataset = ImageDataset(df, preprocess)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=8,
    shuffle=False,
    collate_fn=collate_fn,
)

# Extract embeddings.
# NOTE: this dict holds one entry per *batch*: the key is the tuple of image
# ids returned by collate_fn and the value is that batch's feature array.
embeddings = {}
iter_id = 0  # batches processed so far (used by the debug break below)

with torch.no_grad():  # inference only; skip autograd bookkeeping
    for images, paths in tqdm(dataloader):
        images = images.to(device)
        image_features = model.encode_image(images).cpu().numpy()
        embeddings[paths] = image_features
        iter_id += 1
        # Debug aid: uncomment to stop after a handful of batches.
        # if iter_id > 10:
        #     break

100%|██████████| 23491/23491 [15:46<00:00, 24.81it/s]


In [11]:
# Check dict size: `embeddings` is keyed by per-batch tuples of image ids,
# so this is the number of batches, not the number of images.
len(embeddings.keys())

23491

In [15]:
# Flatten the per-batch dictionary into one entry per image.
# Each key of `embeddings` is a tuple of image ids and each value is an array
# whose rows line up with those ids; pairing them positionally gives an
# {image id: embedding row} mapping. A repeated id keeps its last occurrence.
new_embeddings = {
    image_id: vector
    for batch_ids, batch_vectors in tqdm(embeddings.items())
    for image_id, vector in zip(batch_ids, batch_vectors)
}

100%|██████████| 23491/23491 [00:00<00:00, 63814.26it/s]


In [16]:
# Number of distinct image ids after flattening (one embedding per id).
len(new_embeddings.keys())

1503424

In [19]:
# Write one compressed archive in the working directory with two arrays:
# `embeddings` (the vectors) and `images` (the matching ids). Both lists come
# from the same dict, so position i in one corresponds to position i in the other.
np.savez_compressed('features.npz', embeddings=list(new_embeddings.values()), images=list(new_embeddings.keys()))

In [29]:
# Move the archive written above to its location under the outputs directory
# (shell command; the destination directory must already exist).
!mv features.npz ../outputs/data/feat/clip_embeddings.npz