Importing libraries

In [3]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import clip
import numpy as np
import torch
import clip
from PIL import Image
import os
from sklearn.metrics.pairwise import cosine_similarity

In [4]:
# Pick the best available accelerator: a CUDA GPU first, then Apple-silicon MPS,
# and finally the CPU. On a machine without CUDA this selects exactly what the
# MPS-or-CPU check selected before.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: mps


Loading the CLIP model

In [5]:
# Load CLIP model
# clip.load returns two objects: the model itself and `preprocess`, the transform
# that is applied to every PIL image before it is passed to model.encode_image.
model_name = "ViT-B/32"
model, preprocess = clip.load(model_name, device=device)
# Put the model in evaluation mode; the notebook only runs inference.
# As the last expression of the cell, this call also displays the model's repr.
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [6]:
# Directories
# The dataset root can be overridden with the DEEPFASHION_DATA_DIR environment
# variable so the notebook is not tied to one machine; when the variable is unset
# the original local path is used. The trailing "" passed to os.path.join keeps
# the trailing separator that the original folder strings carried.
DATA_ROOT = os.environ.get("DEEPFASHION_DATA_DIR", "/Users/rahul/Downloads/deepfashion1_data")
image_folder = os.path.join(DATA_ROOT, "images", "")
segm_folder = os.path.join(DATA_ROOT, "segm", "")

Preprocessing images using segmentation masks

In [7]:
def apply_mask(original_img_path, segm_mask_path):
    """Return the image with every pixel outside the segmentation mask set to black.

    Args:
        original_img_path: Path of the image to load (converted to RGB).
        segm_mask_path: Path of the segmentation mask (converted to 8-bit
            grayscale). Any mask value greater than 0 marks a pixel to keep.

    Returns:
        A new PIL image of the same size as the original in which all pixels
        whose mask value is 0 are black.
    """
    source = Image.open(original_img_path).convert('RGB')

    # Bring the mask to the image's size. Nearest-neighbour resampling keeps the
    # mask values as they are instead of blending neighbouring values.
    labels = Image.open(segm_mask_path).convert('L').resize(source.size, Image.NEAREST)

    pixels = np.array(source)
    keep = np.array(labels) > 0

    # Zero out everything the mask does not mark.
    pixels[~keep] = 0
    return Image.fromarray(pixels)

# Get the list of dataset image file names (matched by extension, case-insensitive).
# os.listdir returns entries in arbitrary order, so the names are sorted to make the
# processing order, and with it the row order of the saved embeddings, reproducible.
# Whether a segmentation mask exists for an image is checked later, per image.
image_files = sorted(
    name for name in os.listdir(image_folder)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
)

Computing and storing the embeddings of all images

In [17]:
# Prepare lists to collect the results; entry i of both lists refers to the same image
all_embeddings = []
all_image_paths = []

with torch.no_grad():
    for img_name in tqdm(image_files, desc="Processing dataset"):
        # Construct paths: the mask for "<base>.<ext>" is looked up as "<base>_segm.png"
        base_name = os.path.splitext(img_name)[0]
        segm_name = base_name + "_segm.png"
        segm_path = os.path.join(segm_folder, segm_name)
        img_path = os.path.join(image_folder, img_name)

        # Use the masked image when a mask file exists; otherwise fall back to
        # the unmasked image so that every dataset image still gets an embedding
        if os.path.exists(segm_path):
            masked_img = apply_mask(img_path, segm_path)
        else:
            masked_img = Image.open(img_path).convert('RGB')

        # Preprocess and encode the image, then scale the feature to unit length
        image_input = preprocess(masked_img).unsqueeze(0).to(device)
        image_feature = model.encode_image(image_input)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

        # Convert to numpy and store
        embedding_np = image_feature.cpu().numpy()
        all_embeddings.append(embedding_np)

        # Store the full path, at the same index as its embedding
        all_image_paths.append(img_path)

# Stack all embeddings into a single NumPy array: shape (num_images, embed_dim)
all_embeddings = np.vstack(all_embeddings)

# Save embeddings and corresponding image paths for future use. The paths are
# written here, one per line and in row order, so that the output of this long
# run is complete even if no later cell is executed.
np.save("masked_image_embeddings.npy", all_embeddings)
with open("masked_image_paths.txt", "w") as f:
    f.writelines(path + "\n" for path in all_image_paths)
print("Processing complete. Embeddings saved to 'masked_image_embeddings.npy'.")

Processing dataset: 100%|██████████| 44096/44096 [8:38:50<00:00,  1.42it/s]     


Processing complete. Embeddings saved to 'masked_image_embeddings.npy'.


In [34]:
# all_image_paths

Saving the image paths, then loading the saved paths and embeddings

In [20]:
# Persist the image paths, one per line, in the same order as the embedding rows.
# all_image_paths already holds full paths (each was built with os.path.join when
# its embedding was computed), so the entries are written as they are: joining
# them with image_folder again only works when image_folder is absolute and would
# duplicate the folder prefix for a relative one.
with open("masked_image_paths.txt", "w") as f:
    f.writelines(path + "\n" for path in all_image_paths)

In [8]:
# Load embeddings
all_embeddings = np.load("masked_image_embeddings.npy")  # shape: (num_images, embed_dim)

# Load image paths: one path per line, stored as full paths. Blank lines are skipped.
with open("masked_image_paths.txt", "r") as f:
    all_image_paths = [line.strip() for line in f if line.strip()]

# Row i of all_embeddings must describe all_image_paths[i]. Fail here, with a clear
# message, rather than return wrong matches or hit an IndexError during a search.
if all_embeddings.shape[0] != len(all_image_paths):
    raise ValueError(
        f"Embeddings and image paths are out of sync: {all_embeddings.shape[0]} "
        f"embedding rows vs {len(all_image_paths)} paths."
    )

print(f"Loaded {len(all_image_paths)} embeddings and image paths.")

Loaded 44096 embeddings and image paths.


In [9]:
def get_image_embedding(img_path, segm_mask_path=None):
    """Encode one image with CLIP and return its unit-length embedding.

    The dataset embeddings are computed from segmentation-masked images whenever
    a mask file exists. Passing `segm_mask_path` applies the same masking to the
    query image so that query and dataset are preprocessed alike.

    Args:
        img_path: Path of the query image.
        segm_mask_path: Optional path of a segmentation mask for the image. When
            omitted (the default), the unmasked image is encoded.

    Returns:
        NumPy array with a single row, shape (1, embed_dim), scaled to unit norm.
    """
    if segm_mask_path is None:
        img = Image.open(img_path).convert('RGB')
    else:
        img = apply_mask(img_path, segm_mask_path)

    # unsqueeze(0) adds the batch dimension the encoder expects
    image_input = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        image_feature = model.encode_image(image_input)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
    return image_feature.cpu().numpy()  # (1, embed_dim)

def get_text_embedding(user_text):
    """Encode a text query with CLIP and return its unit-length embedding.

    Args:
        user_text: The query string.

    Returns:
        NumPy array with a single row, one embedding for the one input string.
    """
    # The tokenizer takes a list of strings; a one-element list gives a batch of one
    tokens = clip.tokenize([user_text])
    tokens = tokens.to(device)

    with torch.no_grad():
        encoded = model.encode_text(tokens)

    # Scale to unit length along the feature dimension
    length = encoded.norm(dim=-1, keepdim=True)
    unit = encoded / length
    return unit.cpu().numpy()

In [10]:
def combine_embeddings(image_embed, text_embed, alpha):
    """Blend an image and a text embedding and rescale each row to unit length.

    Args:
        image_embed: Image embedding array, one embedding per row.
        text_embed: Text embedding array of the same shape.
        alpha: Weight given to the text embedding; the image embedding gets
            the remaining (1 - alpha).

    Returns:
        The weighted sum, with every row divided by its Euclidean norm.
    """
    image_weight = 1 - alpha
    mixed = (alpha * text_embed) + (image_weight * image_embed)
    row_norms = np.linalg.norm(mixed, axis=1, keepdims=True)
    return mixed / row_norms

In [11]:
def find_similar_images(query_embedding, dataset_embeddings, dataset_paths, top_k):
    """Return the top_k dataset entries most similar to the query.

    Args:
        query_embedding: Query array with a single row, shape (1, embed_dim).
        dataset_embeddings: Array with one embedding per row.
        dataset_paths: Sequence in which entry i belongs to row i of
            dataset_embeddings.
        top_k: Number of results to return.

    Returns:
        List of (path, similarity) tuples, ordered from most to least similar
        by cosine similarity.
    """
    # One similarity per dataset row, flattened from (1, num_images) to 1-D
    scores = cosine_similarity(query_embedding, dataset_embeddings).flatten()

    # argsort is ascending, so reverse it to get the best match first
    best_first = np.argsort(scores)[::-1]
    return [(dataset_paths[idx], scores[idx]) for idx in best_first[:top_k]]

Image-only similarity search

In [32]:
# Image-only query: embed one image and retrieve its 10 most similar dataset images.
# NOTE(review): the query path is hard-coded to a local file and points into the
# dataset's own images folder, so the query image is expected to rank first.
user_image_path = "/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00005097-05_1_front.jpg"
# No segmentation mask is passed here, so the query image is encoded unmasked.
user_image_embed = get_image_embedding(user_image_path)
results = find_similar_images(user_image_embed, all_embeddings, all_image_paths, top_k=10)

# Each result is a (path, cosine similarity) pair, best match first.
print("Top matches for image-only query:")
for path, score in results:
    print(f"{path} - similarity: {score:.4f}")

Top matches for image-only query:
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00005097-05_1_front.jpg - similarity: 1.0000
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00005528-06_1_front.jpg - similarity: 0.9541
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00005097-08_4_full.jpg - similarity: 0.9450
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00007583-01_1_front.jpg - similarity: 0.9410
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00007564-01_2_side.jpg - similarity: 0.9391
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00004235-01_1_front.jpg - similarity: 0.9391
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00007564-01_1_front.jpg - similarity: 0.9378
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00007564-03_1_front.jpg - similarity: 0.9314
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jack

In [13]:
# user_text = "Hi, can you please provide me some men shorts of any colors"
# user_text_embed = get_text_embedding(user_text)
# results = find_similar_images(user_text_embed, all_embeddings, all_image_paths, top_k=10)

# print("Top matches for text-only query:")
# for path, score in results:
#     print(f"{path} - similarity: {score:.4f}")

Combined image and text query: finding similar images

In [44]:
# Combined query: an example image plus a free-text description of the change wanted.
user_image_path = "/Users/rahul/Downloads/deepfashion1_data/images/MEN-Shirts_Polos-id_00001802-02_1_front.jpg"
user_text = "Show me professional outfits like this, but with a neutral blazer with full sleeves for my sister."
user_image_embed = get_image_embedding(user_image_path)
user_text_embed = get_text_embedding(user_text)

# First pass: retrieve candidates with a text-weighted blend. alpha is the weight of
# the text embedding, so 0.7 favours the text (0.5 would be equal weighting).
initial_alpha = 0.7
combined_embed = combine_embeddings(user_image_embed, user_text_embed, alpha=initial_alpha)
initial_results = find_similar_images(combined_embed, all_embeddings, all_image_paths, top_k=10)

print(f"Initial combined results (text weight {initial_alpha}):")
for path, score in initial_results:
    print(f"{path} - similarity: {score:.4f}")

# Optional second pass: rerank with heavier text weight if desired
rerank_alpha = 0.9
rerank_embed = combine_embeddings(user_image_embed, user_text_embed, alpha=rerank_alpha)

# Re-score only the first-pass candidates: look up the embedding row of each
# candidate path and gather all rows with one indexing operation.
candidate_paths = [res[0] for res in initial_results]
candidate_indices = [all_image_paths.index(cpath) for cpath in candidate_paths]
candidate_embs = all_embeddings[candidate_indices]

# Order the candidates by similarity to the text-heavier query, best first
similarities = cosine_similarity(rerank_embed, candidate_embs).flatten()
sorted_indices = np.argsort(similarities)[::-1]
final_results = [(candidate_paths[i], similarities[i]) for i in sorted_indices]

print("Reranked results (more text weight):")
for path, score in final_results:
    print(f"{path} - similarity: {score:.4f}")

Initial combined results (equal weight):
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Shirts_Polos-id_00001802-02_1_front.jpg - similarity: 0.5813
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Sweaters-id_00001783-01_4_full.jpg - similarity: 0.5603
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Sweaters-id_00005021-01_1_front.jpg - similarity: 0.5602
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00001903-01_1_front.jpg - similarity: 0.5577
/Users/rahul/Downloads/deepfashion1_data/images/WOMEN-Jackets_Coats-id_00001705-03_7_additional.jpg - similarity: 0.5572
/Users/rahul/Downloads/deepfashion1_data/images/WOMEN-Jackets_Coats-id_00001705-03_2_side.jpg - similarity: 0.5563
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Jackets_Vests-id_00001903-03_1_front.jpg - similarity: 0.5552
/Users/rahul/Downloads/deepfashion1_data/images/MEN-Sweaters-id_00001008-05_1_front.jpg - similarity: 0.5538
/Users/rahul/Downloads/deepfashion1_data/images/WOMEN-Ja