In [1]:
import os
import numpy as np
import torch
import open_clip
from PIL import Image
from torchvision import transforms
import faiss

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]



In [2]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine

In [3]:
image_dir = '../data/processed'  # Update path if necessary
image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.lower().endswith(('jpg', 'png'))]
print(f"Found {len(image_paths)} images.")

Found 24975 images.


In [None]:
embeddings = []
metadata = []

for path in image_paths:
    try:
        image = Image.open(path).convert('RGB')
        image_input = preprocess(image).unsqueeze(0)

        with torch.no_grad():
            image_features = model.encode_image(image_input).cpu().numpy()

        embeddings.append(image_features)
        metadata.append(path)
    except Exception as e:
        print(f"Error processing {path}: {e}")

In [None]:
embedding_matrix = np.concatenate(embeddings, axis=0).astype('float32')

index = faiss.IndexFlatL2(embedding_matrix.shape[1])
index.add(embedding_matrix)

os.makedirs('embeddings', exist_ok=True)
faiss.write_index(index, 'embeddings/image.index')
np.save('embeddings/metadata.npy', np.array(metadata))

print("Index and metadata saved.")