In [1]:
# --- Cell 1: Setup and Imports ---

print("Importing libraries...")
import io
import os
import pickle
import tempfile

import cv2
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output
from PIL import Image
from tqdm import tqdm

# Suppress TensorFlow informational messages
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import GlobalMaxPooling2D
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm

print("--> Libraries imported successfully.")

Importing libraries...


2025-07-19 11:15:48.902104: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752923749.270120      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752923749.372478      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


--> Libraries imported successfully.


In [2]:
# --- Cell 2: Define Paths and Load Core Models ---

print("Defining file paths and loading models...")

# Dataset locations. Only the garment database directory is read from the
# dataset; the query image is supplied through the upload widget instead.
BASE_PATH = '/kaggle/input/clothestry/clothes_tryon_dataset/'
CLOTH_DATABASE_PATH = os.path.join(BASE_PATH, 'train', 'cloth')
OUTPUT_PATH = '/kaggle/working/'

# Segmentation model: build the TFLite interpreter, allocate its tensors and
# cache the input/output tensor descriptions used by the helper functions.
try:
    tflite_model_path = '/kaggle/input/deeplabv3-xception65/tflite/ade20k/2/2.tflite'
    segmentation_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    segmentation_interpreter.allocate_tensors()
    segmentation_output_details = segmentation_interpreter.get_output_details()
    segmentation_input_details = segmentation_interpreter.get_input_details()
except Exception as load_error:
    # Nothing downstream can run without the segmenter: report, then re-raise.
    print(f"!!! FATAL ERROR loading TFLite segmentation model: {load_error}")
    raise
else:
    print("--> TFLite Segmentation model loaded successfully.")

# Embedding model: frozen ImageNet ResNet50 backbone (no classifier head)
# followed by global max pooling, giving one feature vector per image.
try:
    base_model = ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=(224, 224, 3),
    )
    base_model.trainable = False
    model_feature_extractor = tf.keras.Sequential([
        base_model,
        GlobalMaxPooling2D(),
    ])
except Exception as build_error:
    print(f"!!! FATAL ERROR creating feature extraction model: {build_error}")
    raise
else:
    print("--> Feature extraction model (ResNet50) created.")

Defining file paths and loading models...
--> TFLite Segmentation model loaded successfully.


INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
I0000 00:00:1752923770.040931      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1752923770.041665      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
--> Feature extraction model (ResNet50) created.


In [8]:
# --- Cell 3: Define Helper Functions (with Maximum Padding) ---

print("Defining helper functions...")

# Embeds a single image with the supplied feature-extraction model.
def extract_features(img_pil, model):
    """Takes a PIL Image and extracts a normalized feature vector.

    Args:
        img_pil: PIL image. It is converted to RGB when it is in another mode
            so that the resized array always has three channels.
        model: Object whose ``predict(batch, verbose=0)`` returns the features
            for a batch of shape (1, 224, 224, 3).

    Returns:
        The flattened feature vector scaled to unit Euclidean length, or
        ``None`` when extraction fails or the vector cannot be normalized
        (zero or non-finite norm).
    """
    try:
        if img_pil.mode != "RGB":
            img_pil = img_pil.convert("RGB")
        img_resized = img_pil.resize((224, 224))
        batch = np.expand_dims(np.array(img_resized), axis=0)
        result = model.predict(preprocess_input(batch), verbose=0).flatten()

        # Dividing by a zero/NaN norm would yield a NaN vector that later
        # poisons the nearest-neighbour index; report it as a failure instead.
        magnitude = norm(result)
        if not np.isfinite(magnitude) or magnitude == 0:
            print("Error during feature extraction: feature vector has zero or non-finite norm")
            return None
        return result / magnitude
    except Exception as e:
        # Best-effort by design: callers treat None as "skip this image".
        print(f"Error during feature extraction: {e}")
        return None

# Segments the target region with the TFLite model and crops the photo around
# it, keeping a generous margin (see the `padding` parameter).
def segment_and_crop_garment(image_path, interpreter, input_details, output_details,
                             target_classes=(13,), padding=60):
    """Detects clothing using a TFLite interpreter and returns a cropped PIL Image.

    Args:
        image_path: Location of the image to analyse (anything ``Image.open``
            accepts).
        interpreter: TFLite interpreter whose tensors are already allocated.
        input_details: Result of ``interpreter.get_input_details()``.
        output_details: Result of ``interpreter.get_output_details()``.
        target_classes: Segmentation class ids that count as foreground.
            NOTE(review): id 13 presumably maps to "person" in the ADE20K
            label set this model was trained on - confirm against the model card.
        padding: Margin in pixels added on every side of the detected
            bounding box (clamped to the image borders).

    Returns:
        An RGB PIL image cropped around the detected region, or ``None`` when
        no pixel belongs to any of ``target_classes``.
    """
    # Load inside a context manager so the file handle is released immediately;
    # callers may delete the file right after this function returns.
    with Image.open(image_path) as source_image:
        original_image_pil = source_image.convert("RGB")
    original_image_np = np.array(original_image_pil)
    image_height, image_width = original_image_np.shape[:2]

    # Resize to the model's input resolution and use the dtype the interpreter
    # reports instead of assuming float32.
    # NOTE(review): pixels are fed unscaled (0-255); confirm the model's
    # expected input range.
    _, input_height, input_width, _ = input_details[0]['shape']
    resized_image = cv2.resize(original_image_np, (input_width, input_height))
    input_data = np.expand_dims(resized_image, axis=0).astype(input_details[0]['dtype'])

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    class_scores = interpreter.get_tensor(output_details[0]['index'])[0]
    segmentation_mask = np.argmax(class_scores, axis=-1)

    # Foreground = any pixel whose predicted class is one of the targets.
    binary_mask = np.isin(segmentation_mask, list(target_classes))

    # Scale the mask back to the original resolution; nearest-neighbour keeps
    # it strictly binary.
    mask_resized_pil = Image.fromarray(binary_mask.astype(np.uint8) * 255)
    mask_resized_pil = mask_resized_pil.resize(original_image_pil.size, Image.NEAREST)
    mask_resized = np.array(mask_resized_pil).astype(bool)

    if not mask_resized.any():
        return None

    rows, cols = np.nonzero(mask_resized)
    ymin, ymax = int(rows.min()), int(rows.max())
    xmin, xmax = int(cols.min()), int(cols.max())

    # PIL's crop box excludes its right/lower edge, hence the +1 so the last
    # foreground row and column are part of the crop.
    crop_box = (
        max(0, xmin - padding),
        max(0, ymin - padding),
        min(image_width, xmax + 1 + padding),
        min(image_height, ymax + 1 + padding),
    )
    return original_image_pil.crop(crop_box)

print("--> Helper functions defined with a larger crop area (padding=60).")

Defining helper functions...
--> Helper functions defined with a larger crop area (padding=60).


In [5]:
# --- NEW Cell 4: Create and Save the Feature Database ---

# Define the paths for the output files
feature_list_path = os.path.join(OUTPUT_PATH, 'embedding.pkl')
filenames_path = os.path.join(OUTPUT_PATH, 'filenames.pkl')

# Cache guard: the expensive extraction only runs when either file is missing.
if not (os.path.exists(feature_list_path) and os.path.exists(filenames_path)):
    print("Feature database not found in this session. Creating a new one...")
    print("(This is a one-time process per session and may take 15-30 minutes.)")

    # Get all the file paths from your clothing database
    cloth_filenames = [os.path.join(CLOTH_DATABASE_PATH, f) for f in sorted(os.listdir(CLOTH_DATABASE_PATH))]

    # feature_list[i] must describe indexed_filenames[i]. Both lists are only
    # appended together, so a skipped image can never shift later indices.
    feature_list = []
    indexed_filenames = []

    # Loop through all files and extract features, showing a progress bar
    for filename in tqdm(cloth_filenames, desc="Extracting Features"):
        try:
            # Context manager releases the file handle as soon as pixels are copied.
            with Image.open(filename) as source_image:
                img_pil = source_image.convert("RGB")
            features = extract_features(img_pil, model_feature_extractor)
        except Exception as e:
            print(f"Skipping file {filename} due to error: {e}")
            continue

        if features is None:
            print(f"Skipping file {filename}: no features could be extracted.")
            continue

        feature_list.append(features)
        indexed_filenames.append(filename)

    # Save the aligned lists to disk (only files that actually have a vector).
    with open(feature_list_path, 'wb') as f:
        pickle.dump(feature_list, f)
    with open(filenames_path, 'wb') as f:
        pickle.dump(indexed_filenames, f)

    skipped_count = len(cloth_filenames) - len(indexed_filenames)
    if skipped_count:
        print(f"--> Skipped {skipped_count} file(s) that could not be processed.")
    print(f"\n--> Database created and saved successfully to '{OUTPUT_PATH}'")
else:
    print(f"--> Found existing feature database in this session. Skipping creation.")

Feature database not found in this session. Creating a new one...
(This is a one-time process per session and may take 15-30 minutes.)


I0000 00:00:1752923939.193054     100 service.cc:148] XLA service 0x7950f4002650 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1752923939.194528     100 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1752923939.194549     100 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1752923939.919526     100 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1752923943.085244     100 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Extracting Features: 100%|██████████| 11647/11647 [21:16<00:00,  9.13it/s]



--> Database created and saved successfully to '/kaggle/working/'


In [6]:
# --- Cell 4: Load Pre-computed Database and Initialize Recommenders ---

print("Loading feature database for recommendation...")
feature_list_path = os.path.join(OUTPUT_PATH, 'embedding.pkl')
filenames_path = os.path.join(OUTPUT_PATH, 'filenames.pkl')

try:
    # SECURITY NOTE: unpickling can execute arbitrary code. Only load files
    # that this notebook (or another trusted source) produced.
    # `with` closes each file handle even if unpickling fails.
    with open(feature_list_path, 'rb') as f:
        feature_list = np.array(pickle.load(f))
    with open(filenames_path, 'rb') as f:
        filenames = pickle.load(f)
    print(f"--> Database with {len(feature_list)} embeddings loaded.")
except FileNotFoundError:
    print("!!! ERROR: Database files (embedding.pkl, filenames.pkl) not found.")
    print("!!! Please run the feature extraction notebook/cell first to create them.")
    raise
except Exception as e:
    print(f"!!! FATAL ERROR loading database files: {e}")
    raise

# Neighbour indices are used to look up filenames, so the two collections must
# line up one-to-one; a mismatch would silently show the wrong garments.
if len(feature_list) != len(filenames):
    raise ValueError(
        f"Database is inconsistent: {len(feature_list)} embeddings but "
        f"{len(filenames)} filenames. Delete the .pkl files and rebuild."
    )

# kneighbors() rejects n_neighbors larger than the number of fitted samples,
# so cap the requested 10 for very small databases.
neighbor_count = max(1, min(10, len(feature_list)))

# Initialize TWO Recommenders with Different Metrics for 10 neighbors
neighbors_euclidean = NearestNeighbors(n_neighbors=neighbor_count, algorithm='brute', metric='euclidean')
neighbors_euclidean.fit(feature_list)
print("--> Euclidean Recommender is ready.")

neighbors_cosine = NearestNeighbors(n_neighbors=neighbor_count, algorithm='brute', metric='cosine')
neighbors_cosine.fit(feature_list)
print("--> Cosine Recommender is ready.")

Loading feature database for recommendation...
--> Database with 11647 embeddings loaded.
--> Euclidean Recommender is ready.
--> Cosine Recommender is ready.


In [9]:
# --- Cell 5: The Interactive Upload and Recommendation Cell ---

# This is the main function that runs when you upload an image
# Runs one neighbour query, prints the ranked list and draws a 2x5 image grid.
def _show_recommendations(recommender, query_features, banner, hint,
                          rank_label, title_template, score_fn):
    """Query `recommender` and display its results.

    Args:
        recommender: Fitted nearest-neighbour index exposing ``kneighbors``.
        query_features: Feature vector of the query garment.
        banner: Text printed between the two ``=`` rules.
        hint: One-line explanation of how to read the score.
        rank_label: Name of the score in the printed ranking.
        title_template: ``str.format`` template for each subplot title; it
            receives ``rank`` and ``score``.
        score_fn: Maps a raw neighbour distance to the displayed score.
    """
    distances, indices = recommender.kneighbors([query_features])
    ranked = [(filenames[idx], score_fn(dist)) for idx, dist in zip(indices[0], distances[0])]

    print("\n" + "="*50 + banner + "="*50)
    print(hint)
    # Iterate over what was actually returned rather than assuming 10 hits.
    for rank, (path, score) in enumerate(ranked, start=1):
        print(f"  Rank {rank:02d}: {rank_label} = {score:.4f}, File Code = {os.path.basename(path)}")

    fig, axes = plt.subplots(2, 5, figsize=(20, 8)); axes = axes.flatten()
    for rank, (ax, (path, score)) in enumerate(zip(axes, ranked), start=1):
        # imshow copies the pixels right away, so the file can be closed here.
        with Image.open(path) as img:
            ax.imshow(img)
        ax.set_title(title_template.format(rank=rank, score=score))
    # Hide the frames of every panel, including any left empty.
    for ax in axes:
        ax.axis('off')
    plt.tight_layout(); plt.show()


def find_recommendations(uploaded_file_data):
    """Segment an uploaded photo, embed the crop and show similar garments.

    Args:
        uploaded_file_data: Mapping with the upload's ``'name'`` and its raw
            bytes under ``'content'``.

    Prints progress and failure messages and renders the figures; returns
    nothing. Both rankings are shown for comparison; because every feature
    vector is scaled to unit length, Euclidean and cosine ordering coincide.
    """
    file_content = uploaded_file_data['content']

    # --- Pipeline Step 1: Save the uploaded file temporarily ---
    # A uniquely named temp file is used instead of the client-supplied name,
    # which could contain path components or collide with an existing file.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(file_content)
        temp_path = temp_file.name

    # --- Pipeline Step 2: Segment the garment from the uploaded image ---
    try:
        cropped_garment = segment_and_crop_garment(
            temp_path,
            segmentation_interpreter,
            segmentation_input_details,
            segmentation_output_details
        )
    finally:
        # Remove the temporary file even when segmentation raises.
        os.remove(temp_path)

    # --- Pipeline Step 3: Run the rest of the pipeline ---
    if cropped_garment is None:
        print("\n!!! FAILED: Could not detect a garment in your uploaded image.")
        display(Image.open(io.BytesIO(file_content)))
        return
    print("1. Garment detected and cropped.")

    query_features = extract_features(cropped_garment, model_feature_extractor)
    if query_features is None:
        print("!!! FAILED: Could not extract features from the cropped garment.")
        return
    print("2. Features extracted from cropped garment.")

    # Display Query Analysis
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(Image.open(io.BytesIO(file_content))); axes[0].set_title("Your Upload"); axes[0].axis('off')
    axes[1].imshow(cropped_garment); axes[1].set_title("Auto-Cropped Garment"); axes[1].axis('off')
    plt.suptitle("Query Analysis", fontsize=16); plt.show()

    # A) Get and Display Euclidean Results
    _show_recommendations(
        neighbors_euclidean, query_features,
        banner="\n      RESULTS USING EUCLIDEAN DISTANCE (Top 10)\n",
        hint="--> Lower distance is better.",
        rank_label="Euclidean Distance",
        title_template="Euc Rec {rank}\nDist: {score:.2f}",
        score_fn=lambda distance: distance,
    )

    # B) Get and Display Cosine Results (similarity = 1 - cosine distance)
    _show_recommendations(
        neighbors_cosine, query_features,
        banner="\n        RESULTS USING COSINE SIMILARITY (Top 10)\n",
        hint="--> Higher similarity is better.",
        rank_label="Cosine Similarity",
        title_template="Cos Rec {rank}\nSim: {score:.2f}",
        score_fn=lambda distance: 1 - distance,
    )

# --- Widget Setup ---
# io lets the uploaded bytes be wrapped in a file-like object for PIL
import io

uploader = widgets.FileUpload(accept='image/*', multiple=False, description='Upload an Image')
output_area = widgets.Output()

def on_upload_change(change):
    """Re-run the recommendation pipeline whenever the upload widget's value changes.

    Everything printed or plotted is captured inside `output_area`, which is
    cleared first so results from a previous upload do not pile up.
    """
    with output_area:
        clear_output(wait=True)
        if not uploader.value:
            # Value was reset to empty: nothing to process.
            return
        # The widget value is a tuple of uploads; multiple=False, so use the first.
        find_recommendations(uploader.value[0])

uploader.observe(on_upload_change, names='value')

print("Please upload an image to start the recommendation process.")
display(uploader, output_area)

Please upload an image to start the recommendation process.


FileUpload(value=(), accept='image/*', description='Upload an Image')

Output()