In [None]:

print("Upgrading pip, setuptools, wheel...")
!pip install -q --upgrade pip setuptools wheel
print("Finished upgrading pip, setuptools, wheel.")


print("Installing specific numpy version (1.23.5)...")
!pip install -q numpy==1.23.5 --force-reinstall --no-deps
print("Finished installing numpy.")

# Step 2: Install fashion_clip separately, without a specific version,
# to let pip find the most compatible one with the current Kaggle environment's PyTorch.
print("Installing fashion_clip...")
!pip install -q fashion_clip
print("Finished installing fashion_clip.")

# Step 3: Install remaining packages.
# Note: transformers version 4.30.2 might be old. If you face issues, try removing version constraint.
# scikit-learn 1.2.2 is also stable.
print("Installing remaining core packages (transformers, scikit-learn, pillow, opencv, tqdm)...")
!pip install -q transformers==4.30.2 scikit-learn==1.2.2 Pillow==9.4.0 opencv-python==4.7.0.72 tqdm
print("Finished installing remaining core packages.")


import os
import torch
from PIL import Image
import numpy as np
from sklearn.neighbors import NearestNeighbors # For similarity search
import matplotlib.pyplot as plt
from IPython.display import Image as DisplayImage, display
import glob # For finding image files
from tqdm.notebook import tqdm # For progress bars
import json # For saving/loading metadata
import joblib # For saving/loading the NearestNeighbors model

# Import models and processors
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from fashion_clip.fashion_clip import FashionCLIP # This should now work if fashion_clip installed

print("All necessary libraries installed and imported.")
print(f"PyTorch version: {torch.__version__}")
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")

# Set a working directory for saving processed files (NearestNeighbors model, metadata)
WORKING_DIR = "/kaggle/working"
os.makedirs(WORKING_DIR, exist_ok=True)
print(f"Working directory set to: {WORKING_DIR}")

# Define the root path to your dataset
DATASET_ROOT_PATH = "/kaggle/input/clothestry/clothes_tryon_dataset"
print(f"Dataset root path set to: {DATASET_ROOT_PATH}")

Upgrading pip, setuptools, wheel...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m23.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m44.9 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.8.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
pandas-gbq 0.29.1 requires google-api-core<3.0.0,>=2.10.2, but you have google-api-core 1.34.1 which is incompatible.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.
bigframes 2.8.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.
bigframes 2.8.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.[0m

2025-07-16 17:13:31.751373: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752686012.102060      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752686012.206466      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


All necessary libraries installed and imported.
PyTorch version: 2.6.0+cu124
Is CUDA available: True
CUDA device name: Tesla T4
Working directory set to: /kaggle/working
Dataset root path set to: /kaggle/input/clothestry/clothes_tryon_dataset


In [2]:
# Cell 2: Model Loading

# Load FashionCLIP model
print("Loading FashionCLIP model...")
fclip = FashionCLIP('fashion-clip')
fclip.eval() # Set to evaluation mode
print("FashionCLIP model loaded.")

# Load SegFormer model for semantic segmentation
print("Loading SegFormer model...")
processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model_segformer = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# Move models to GPU if available
if torch.cuda.is_available():
    fclip.to('cuda')
    model_segformer.to('cuda')
    print("Models moved to GPU.")
else:
    print("CUDA not available. Models running on CPU (may be slower).")

model_segformer.eval() # Set to evaluation mode
print("SegFormer model loaded.")

Loading FashionCLIP model...




config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/568 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

AttributeError: 'FashionCLIP' object has no attribute 'eval'

In [None]:

PERSON_CLASS_ID_ADE20K = 13
PADDING_PIXELS = 10 
def segment_and_crop(image_pil: Image.Image, seg_processor, seg_model, target_class_id: int = None) -> Image.Image:
    """
    Segments an image using SegFormer and crops to the bounding box of the target class.
    If target_class_id is None, it attempts to find the largest non-background segment.
    """
    inputs = seg_processor(images=image_pil, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to('cuda') for k, v in inputs.items()}

    with torch.no_grad():
        outputs = seg_model(**inputs)
        logits = outputs.logits.cpu() # Move logits back to CPU for numpy operations

    # Resize logits to original image size
    upsampled_logits = torch.nn.functional.interpolate(
        logits,
        size=image_pil.size[::-1], # (height, width)
        mode="bilinear",
        align_corners=False,
    )
    pred_mask = upsampled_logits.argmax(dim=1)[0] # Get the predicted class mask

    # Try to find the target class or the largest non-background segment
    if target_class_id is None:
        unique_classes, counts = torch.unique(pred_mask, return_counts=True)
        # Exclude background (class 0 often)
        non_background_classes = unique_classes[unique_classes != 0]
        if len(non_background_classes) > 0:
            # Find the most frequent non-background class
            max_count_idx = torch.argmax(counts[unique_classes != 0])
            target_class_id = non_background_classes[max_count_idx].item()
        else:
            # Fallback if no foreground object is detected
            # print("Warning: No significant foreground class found. Returning original image.")
            return image_pil

    binary_mask = (pred_mask == target_class_id).numpy().astype(np.uint8)

    # Find bounding box
    coords = np.argwhere(binary_mask)
    if coords.size == 0:
        # print(f"Warning: No pixels found for target class {target_class_id}. Returning original image.")
        return image_pil

    y_min, x_min = coords.min(axis=0)
    y_max, x_max = coords.max(axis=0)

    # Add padding
    x_min = max(0, x_min - PADDING_PIXELS)
    y_min = max(0, y_min - PADDING_PIXELS)
    x_max = min(image_pil.width, x_max + PADDING_PIXELS)
    y_max = min(image_pil.height, y_max + PADDING_PIXELS)

    cropped_image = image_pil.crop((x_min, y_min, x_max, y_max))
    return cropped_image

def encode_image(image_pil: Image.Image, fclip_model: FashionCLIP) -> np.ndarray:
    """
    Encodes a PIL Image into a FashionCLIP embedding.
    """
    with torch.no_grad():
        embedding = fclip_model.encode_images([image_pil])[0]
    return embedding.cpu().numpy() # Ensure it's on CPU and converted to numpy

print("Segmentation and embedding helper functions defined.")

In [None]:
# Cell 4: Collect Image Paths from Dataset

image_paths = []

# Paths to the 'cloth' directories within train and test
train_cloth_path = os.path.join(DATASET_ROOT_PATH, "train", "cloth")
test_cloth_path = os.path.join(DATASET_ROOT_PATH, "test", "cloth")

print(f"Collecting images from: {train_cloth_path} and {test_cloth_path}")

# Search for common image extensions in both train/cloth and test/cloth
for ext in ['*.jpg', '*.jpeg', '*.png']:
    image_paths.extend(glob.glob(os.path.join(train_cloth_path, ext), recursive=False)) # No recursive needed, files are direct
    image_paths.extend(glob.glob(os.path.join(test_cloth_path, ext), recursive=False)) # No recursive needed

if not image_paths:
    print(f"No images found in '{train_cloth_path}' or '{test_cloth_path}'.")
    print("Please verify the DATASET_ROOT_PATH and the internal structure of the dataset.")
    raise RuntimeError("No images found for processing. Check data path and structure.")

print(f"Found {len(image_paths)} images for indexing.")

# Verify a cropped image from your dataset
if image_paths:
    # Use a try-except to handle potential issues with the first image
    try:
        sample_img_pil = Image.open(image_paths[0]).convert("RGB")
        cropped_sample_img = segment_and_crop(sample_img_pil, processor, model_segformer, target_class_id=PERSON_CLASS_ID_ADE20K)
        print(f"\nOriginal image size: {sample_img_pil.size}, Cropped image size: {cropped_sample_img.size}")
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(sample_img_pil)
        plt.title("Original Image")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(cropped_sample_img)
        plt.title("SegFormer Cropped Image")
        plt.axis('off')
        plt.show()
    except Exception as e:
        print(f"Error displaying sample image: {e}")
        print("This might happen if the first image is corrupted or unreadable.")

In [None]:
# Cell 5: Vector Database Setup (FAISS) and Data Ingestion

# Define embedding dimension (FashionCLIP outputs 512-dim vectors)
EMBEDDING_DIM = 512
index = faiss.IndexFlatIP(EMBEDDING_DIM) # Using Inner Product for cosine similarity with normalized vectors

# To store metadata (like image paths) as FAISS only stores vectors
image_metadata = [] # List of dictionaries: {"id": "img_id", "path": "image_path"}

print("Starting data ingestion into FAISS index...")
# Use tqdm for a progress bar
# Limit to a smaller number for quick testing in Kaggle if 8000 is too slow
# For full dataset, remove `[:8000]` or adjust.
# For 8000 images, this might take ~15-30 minutes on a GPU.
processing_images = image_paths # if you want to process all
# processing_images = image_paths[:2000] # Uncomment to process a subset for faster demo

for i, img_path in enumerate(tqdm(processing_images, desc="Processing images")):
    try:
        # Use a unique ID for each image, e.g., relative path from dataset root or just filename
        img_id = os.path.relpath(img_path, DATASET_ROOT_PATH) # Use relative path as ID
        img_pil = Image.open(img_path).convert("RGB")

        # Step 1: Segment and crop the image
        cropped_img_pil = segment_and_crop(img_pil, processor, model_segformer, target_class_id=PERSON_CLASS_ID_ADE20K)

        # Step 2: Generate FashionCLIP embedding
        embedding = encode_image(cropped_img_pil, fclip)

        # Ensure embedding is L2 normalized for cosine similarity (which is dot product after normalization)
        embedding = embedding / np.linalg.norm(embedding)

        # Add to FAISS index
        index.add(np.array([embedding], dtype=np.float32))

        # Store metadata
        image_metadata.append({"id": img_id, "path": img_path})

    except Exception as e:
        # Print errors but don't stop the whole process for one bad image
        print(f"\nError processing {img_path}: {e}")

print(f"\nData ingestion complete. Total vectors in FAISS: {index.ntotal}")

# Save the FAISS index and metadata for persistence
faiss_index_path = os.path.join(WORKING_DIR, "tshirt_faiss_index.bin")
faiss.write_index(index, faiss_index_path)
print(f"FAISS index saved to {faiss_index_path}")

import json
metadata_path = os.path.join(WORKING_DIR, "tshirt_metadata.json")
with open(metadata_path, 'w') as f:
    json.dump(image_metadata, f)
print(f"Metadata saved to {metadata_path}")

In [None]:

print("Loading FAISS index and metadata for query function...")
loaded_index = faiss.read_index(os.path.join(WORKING_DIR, "tshirt_faiss_index.bin"))
with open(os.path.join(WORKING_DIR, "tshirt_metadata.json"), 'r') as f:
    loaded_image_metadata = json.load(f)
print(f"Loaded FAISS index with {loaded_index.ntotal} vectors and {len(loaded_image_metadata)} metadata entries.")


def get_recommendations(query_image_path: str, top_k: int = 10) -> list:
    """
    Takes a query image path, processes it, queries the FAISS index,
    and returns top_k similar images with their scores.
    """
    print(f"\nProcessing query image: {query_image_path}")
    try:
        query_img_pil = Image.open(query_image_path).convert("RGB")

        cropped_query_img = segment_and_crop(query_img_pil, processor, model_segformer, target_class_id=PERSON_CLASS_ID_ADE20K)

        query_embedding = encode_image(cropped_query_img, fclip)

        query_embedding = query_embedding / np.linalg.norm(query_embedding)

        
        D, I = loaded_index.search(np.array([query_embedding], dtype=np.float32), top_k)

        recommendations = []
        # D[0] are scores, I[0] are indices
        for rank, (score, idx) in enumerate(zip(D[0], I[0])):
            if idx < len(loaded_image_metadata): # Ensure index is valid
                recommended_item = loaded_image_metadata[idx]
                recommendations.append({
                    "rank": rank + 1,
                    "id": recommended_item["id"],
                    "image_path": recommended_item["path"],
                    "similarity_score": float(score)
                })
        return recommendations

    except Exception as e:
        print(f"Error getting recommendations for {query_image_path}: {e}")
        return []

print("Recommendation query function defined.")

In [None]:
# Cell 7: Simulate Web Interface - User Query and Display Results

import random

# Pick a random image from your processed dataset to use as a query
# We use the full image_paths list, not just the subset processed if you used `processing_images`
if image_paths:
    query_image_path_example = random.choice(image_paths)
    print(f"Simulating user uploading: {query_image_path_example}")
else:
    raise RuntimeError("No images available to use as a query. Ensure data ingestion worked.")

# Display the uploaded (query) image
print("\n--- Query Image ---")
display(DisplayImage(query_image_path_example, width=300))

# Get recommendations
recommendations = get_recommendations(query_image_path_example, top_k=10)

print(f"\n--- Top {len(recommendations)} Recommendations ---")

if recommendations:
    # Set up matplotlib for displaying images in a grid
    # Filter out the query image itself if it's an identical match from the dataset
    filtered_recommendations = [rec for rec in recommendations if rec['image_path'] != query_image_path_example]

    if len(filtered_recommendations) < 10 and len(recommendations) > 0 and len(filtered_recommendations) != len(recommendations):
        print(f"Note: Query image was among top results and filtered out. Showing {len(filtered_recommendations)} unique recommendations.")
    elif not filtered_recommendations:
        print("No unique recommendations found after filtering (perhaps all top matches were the query image).")

    display_recs = filtered_recommendations if filtered_recommendations else recommendations # Show filtered, but if all filtered, show original

    if not display_recs:
        print("No recommendations to display.")
    else:
        num_cols = 5
        num_rows = (len(display_recs) + num_cols - 1) // num_cols
        
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 3, num_rows * 4))
        if num_rows == 1 and num_cols == 1: # Handle single subplot case
            axes = np.array([axes]) # Make it iterable
        axes = axes.flatten() # Flatten for easy iteration

        for i, rec in enumerate(display_recs):
            if i >= len(axes): # Prevent index out of bounds if few recommendations
                break
            try:
                img = Image.open(rec["image_path"]).convert("RGB")
                axes[i].imshow(img)
                # Split ID if it contains path separators for cleaner display
                display_id = rec['id'].split(os.path.sep)[-1]
                axes[i].set_title(f"Score: {rec['similarity_score']:.3f}\nID: {display_id}")
                axes[i].axis('off')
            except Exception as e:
                print(f"Could not load recommendation image {rec['image_path']}: {e}")
                axes[i].text(0.5, 0.5, "Image Error", horizontalalignment='center', verticalalignment='center', transform=axes[i].transAxes)
                axes[i].axis('off')

        # Hide any unused subplots
        for j in range(len(display_recs), len(axes)):
            fig.delaxes(axes[j])

        plt.tight_layout()
        plt.show()
else:
    print("No recommendations found.")

print("\nRecommender system demo complete!")