In [None]:
# --- Imports ---
import os
import zipfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.cluster import KMeans

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# CLIP for feature extraction
import clip


# --- Step 1: Dataset unzip & prep ---
ZIP_PATH = "images_compressed.zip"
IMG_DIR = "images"
CSV_PATH = "images.csv"

# Extract the archive at most once: an existing IMG_DIR is taken as evidence that
# extraction already happened, which keeps re-running this cell cheap.
if os.path.exists(IMG_DIR):
    print("âœ… Images already extracted:", IMG_DIR)
else:
    with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
        archive.extractall(IMG_DIR)
    print("âœ… Images extracted to:", IMG_DIR)

# One row per image; the 'image' column holds the file's base name (its id).
df = pd.read_csv(CSV_PATH)

# Build each row's on-disk path. Assumes every file is '<id>.jpg' -- change
# IMAGE_SUFFIX if the archive contains .png files instead.
IMAGE_SUFFIX = ".jpg"
df['image_path'] = [os.path.join(IMG_DIR, image_id + IMAGE_SUFFIX) for image_id in df['image']]

for caption, value in (("Dataset shape:", df.shape), ("Columns:", df.columns)):
    print(caption, value)
print(df.head())


âœ… Images already extracted: images
Dataset shape: (5403, 5)
Columns: Index(['image', 'sender_id', 'label', 'kids', 'image_path'], dtype='object')
                                  image  sender_id     label   kids  \
0  4285fab0-751a-4b74-8e9b-43af05deee22        124  Not sure  False   
1  ea7b6656-3f84-4eb3-9099-23e623fc1018        148   T-Shirt  False   
2  00627a3f-0477-401c-95eb-92642cbe078d         94  Not sure  False   
3  ea2ffd4d-9b25-4ca8-9dc2-bd27f1cc59fa         43   T-Shirt  False   
4  3b86d877-2b9e-4c8b-a6a2-1d87513309d0        189     Shoes  False   

                                        image_path  
0  images/4285fab0-751a-4b74-8e9b-43af05deee22.jpg  
1  images/ea7b6656-3f84-4eb3-9099-23e623fc1018.jpg  
2  images/00627a3f-0477-401c-95eb-92642cbe078d.jpg  
3  images/ea2ffd4d-9b25-4ca8-9dc2-bd27f1cc59fa.jpg  
4  images/3b86d877-2b9e-4c8b-a6a2-1d87513309d0.jpg  


In [None]:
# Re-inspect the schema, now including the derived 'image_path' column.
df_columns = df.columns
print(df_columns)


Index(['image', 'sender_id', 'label', 'kids', 'image_path'], dtype='object')


In [None]:

from PIL import Image

def safe_open_image(path, size=(224, 224)):
    """Open ``path`` as an RGB PIL image, falling back to a black placeholder.

    Best-effort loader: some files in this dataset cannot be loaded (see the
    "Failed:" lines printed by the embedding cell). Returning a placeholder keeps
    a DataLoader from aborting mid-epoch.

    Args:
        path: Filesystem path of the image.
        size: (width, height) of the black placeholder returned on failure.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    try:
        # ``with`` releases the file handle even when decoding fails halfway;
        # ``convert`` returns a fully loaded copy that outlives the file.
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception:
        # Still best-effort, but unlike a bare ``except`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return Image.new("RGB", size, (0, 0, 0))

# --- Step 2: Dataset class ---
class WardrobeDataset(Dataset):
    """Map-style dataset yielding ``(image, label_index, metadata)`` for each row of ``df``.

    Args:
        df: DataFrame with at least the columns ``image_path``, ``label`` and ``kids``.
        transform: Optional callable applied to the loaded PIL image.
        use_metadata: If True, metadata is ``[float(kids)]``; otherwise a single zero.
        classes: Optional ordered list of label names defining the label->index
            mapping. Pass the training split's ``classes`` when building other
            splits so all splits share one mapping. By default the sorted labels
            found in ``df`` are used, which differ between splits whenever a split
            is missing a class.
    """

    def __init__(self, df, transform=None, use_metadata=True, classes=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.use_metadata = use_metadata
        self.classes = sorted(df['label'].unique()) if classes is None else list(classes)
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = safe_open_image(row['image_path'])

        if self.transform:
            image = self.transform(image)

        try:
            label = self.class_to_idx[row['label']]
        except KeyError:
            raise KeyError(
                f"label {row['label']!r} is not in this dataset's classes {self.classes}"
            ) from None

        # Metadata (only kids flag here, can extend later)
        if self.use_metadata:
            metadata = torch.tensor([int(row['kids'])], dtype=torch.float32)
        else:
            metadata = torch.zeros(1)

        return image, label, metadata



In [None]:
# --- Step 3: Transforms ---
# Normalization constants matching CLIP's own image preprocessing; the CLIP
# encoder expects inputs normalized this way.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Training pipeline: squash to 224x224 (aspect ratio is not preserved), then
# light augmentation (random mirror, rotation within +/-10 degrees).
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])

# Deterministic pipeline for validation / inference.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])



In [None]:
# --- Step 4: CLIP Feature Extractor ---
class CLIPFeatureExtractor(nn.Module):
    """Frozen CLIP ViT-B/32 image encoder exposed as an ``nn.Module``.

    ``forward`` runs without autograd, so the returned features never carry
    gradients back into CLIP; only downstream heads are trained.

    Args:
        device: Device string passed straight to ``clip.load``.
    """

    def __init__(self, device):
        super().__init__()
        clip_model, _preprocess = clip.load("ViT-B/32", device=device)
        clip_model.eval()
        self.model = clip_model

    @torch.no_grad()
    def forward(self, images):
        # images: a batch already preprocessed for CLIP (see the transforms cell).
        return self.model.encode_image(images)


In [None]:
import clip
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Row i of clip_embs.npy is the embedding of image id img_ids[i]. Unreadable
# images are skipped, so these rows are NOT aligned with df rows -- consumers
# must join through img_ids.npy.
embeddings = []
img_ids = []

for _, row in tqdm(df.iterrows(), total=len(df)):
    # Reuse the path built in the prep cell rather than re-hardcoding the folder.
    img_path = row['image_path']
    try:
        with Image.open(img_path) as img:
            image = preprocess(img).unsqueeze(0).to(device)
    except Exception as exc:
        # Best-effort: skip missing/corrupt files, but report why.
        print(f"Failed: {img_path} ({type(exc).__name__}: {exc})")
        continue

    # Kept outside the try: a model error is not a bad-image error and should surface.
    with torch.no_grad():
        emb = model.encode_image(image).cpu().numpy().flatten()
    embeddings.append(emb)
    img_ids.append(row['image'])

embeddings = np.array(embeddings)
np.save("clip_embs.npy", embeddings)

# Save ids alongside (same order as the embedding rows)
np.save("img_ids.npy", np.array(img_ids))
print(f"Embedded {len(img_ids)}/{len(df)} images")


 10%|â–‰         | 533/5403 [02:05<27:08,  2.99it/s]

Failed: /content/images/d028580f-9a98-4fb5-a6c9-5dc362ad3f09.jpg


 13%|â–ˆâ–Ž        | 703/5403 [02:44<22:31,  3.48it/s]

Failed: /content/images/1d0129a1-f29a-4a3f-b103-f651176183eb.jpg


 16%|â–ˆâ–Œ        | 861/5403 [03:21<15:42,  4.82it/s]

Failed: /content/images/784d67d4-b95e-4abb-baf7-8024f18dc3c8.jpg


 31%|â–ˆâ–ˆâ–ˆ       | 1662/5403 [06:28<14:04,  4.43it/s]

Failed: /content/images/c60e486d-10ed-4f64-abab-5bb698c736dd.jpg


 33%|â–ˆâ–ˆâ–ˆâ–Ž      | 1763/5403 [06:51<12:53,  4.70it/s]

Failed: /content/images/040d73b7-21b5-4cf2-84fc-e1a80231b202.jpg


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5403/5403 [21:04<00:00,  4.27it/s]


In [None]:
# --- Step 5: Compatibility Model ---
class OutfitCompatibilityNet(nn.Module):
    """MLP scoring how well two item embeddings go together.

    Input is the concatenation ``[top_embed, bottom_embed, metadata]``. The output
    has shape ``(batch, 1)`` and passes through a final ``Sigmoid``, so it pairs
    with ``nn.BCELoss``.

    Args:
        embed_dim: Width of each item embedding.
        metadata_dim: Width of the metadata vector appended after the two embeddings.
    """

    def __init__(self, embed_dim, metadata_dim=1):
        super().__init__()
        self.metadata_dim = metadata_dim
        # Keep the attribute name ``fc`` and layer order: saved state_dict keys
        # (fc.0.weight, fc.2.weight, ...) depend on them.
        self.fc = nn.Sequential(
            nn.Linear(embed_dim * 2 + metadata_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, top_embed, bottom_embed, metadata=None):
        if metadata is None:
            # All-zero metadata of the width the first Linear layer expects
            # (previously hard-coded to 1, which broke metadata_dim != 1).
            metadata = torch.zeros(top_embed.size(0), self.metadata_dim, device=top_embed.device)
        x = torch.cat([top_embed, bottom_embed, metadata], dim=1)
        return self.fc(x)



In [None]:
# --- Step 6: Color Harmony Explainability ---
def extract_dominant_colors(img_path, k=3, size=(64, 64), random_state=0):
    """Return up to ``k`` dominant RGB colors of an image via k-means clustering.

    Args:
        img_path: Path of the image file.
        k: Requested number of colors. It is reduced to the number of distinct
            pixel colors when the image has fewer (e.g. a flat placeholder).
        size: The image is downsampled to this (width, height) before clustering,
            which keeps clustering cheap.
        random_state: Seed for KMeans, so the same image always yields the same
            palette and therefore the same explanation text.

    Returns:
        np.ndarray of shape (n, 3) with n <= k: integer RGB centroids, rounded.
        Rows are not sorted by pixel frequency.
    """
    with Image.open(img_path) as img:
        pixels = np.asarray(img.convert("RGB").resize(size)).reshape(-1, 3)

    n_clusters = min(k, len(np.unique(pixels, axis=0)))
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    kmeans.fit(pixels)
    # Round rather than truncate so e.g. 179.9 becomes 180, not 179.
    return np.rint(kmeans.cluster_centers_).astype(int)


def explain_color_match(img1_path, img2_path):
    """Describe the color relationship between two images in one sentence.

    Compares every dominant color of the first image with every dominant color of
    the second. The first pair whose RGB sum is within 50 of (255, 255, 255) per
    channel counts as "complementary". If no pair qualifies, a generic fallback
    message is returned; no analogous/neutral test is actually performed.
    """
    palette_a = extract_dominant_colors(img1_path, k=3)
    palette_b = extract_dominant_colors(img2_path, k=3)
    white = [255, 255, 255]

    match = next(
        ((a, b) for a in palette_a for b in palette_b
         if np.allclose(a + b, white, atol=50)),
        None,
    )
    if match is None:
        return "Analogous or neutral harmony detected"
    color_a, color_b = match
    return f"Complementary match: {color_a} vs {color_b}"



In [None]:
# --- Step 7: Training Loop ---
def train_model(clip_extractor, model, train_loader, criterion, optimizer, device, epochs=5):
    """Train ``model`` on pairs built from adjacent samples of each batch.

    Within a batch, samples (0, 1), (2, 3), ... form pairs. A pair's target is 1.0
    when both items share a label and 0.0 otherwise; the first item's metadata is
    used for the pair.
    NOTE(review): this proxy target rewards same-*category* pairs, which is not
    obviously the same as outfit compatibility -- confirm it is intended.

    Args:
        clip_extractor: Callable mapping an image batch to embeddings.
        model: Pair scorer called as ``model(top, bottom, metadata)``.
        train_loader: Iterable of ``(images, labels, metadata)`` batches.
        criterion: Loss taking ``(preds, targets)`` with targets of shape (pairs, 1).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        list[float]: mean loss per batch for each epoch (existing callers may ignore it).
    """
    history = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0

        for images, labels, metadata in train_loader:
            # Pairing needs an even count: drop a trailing odd sample (e.g. in a short
            # last batch) instead of crashing on mismatched halves.
            n_even = len(images) - len(images) % 2
            if n_even == 0:
                continue
            images = images[:n_even].to(device)
            labels = labels[:n_even].to(device)
            metadata = metadata[:n_even].to(device)

            # The CLIP encoder may produce fp16 features (e.g. on CUDA); the head is fp32.
            embeddings = clip_extractor(images).float()

            top_embed, bottom_embed = embeddings[0::2], embeddings[1::2]
            meta = metadata[0::2]
            labels_pair = (labels[0::2] == labels[1::2]).float().unsqueeze(1)

            preds = model(top_embed, bottom_embed, meta)
            loss = criterion(preds, labels_pair)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        # Report the mean; a summed loss grows with the number of batches.
        avg_loss = total_loss / max(n_batches, 1)
        history.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    return history



In [None]:
# --- Step 8: Main ---
if __name__ == "__main__":
    # Seed torch so weight init, shuffling and augmentation are reproducible.
    torch.manual_seed(42)

    train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)

    train_dataset = WardrobeDataset(train_df, transform=train_transform)
    # NOTE(review): val_dataset is built but never evaluated in this cell.
    val_dataset = WardrobeDataset(val_df, transform=test_transform)

    # drop_last keeps every batch at 32 (even), which train_model's
    # adjacent-sample pairing relies on.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_extractor = CLIPFeatureExtractor(device)
    model = OutfitCompatibilityNet(embed_dim=512, metadata_dim=1).to(device)

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train
    train_model(clip_extractor, model, train_loader, criterion, optimizer, device, epochs=5)

    # Demo Explanation
    img1 = df.iloc[0]['image_path']
    img2 = df.iloc[1]['image_path']
    print("Color Explanation:", explain_color_match(img1, img2))

Epoch 1/5, Loss: 46.0534
Epoch 2/5, Loss: 43.2458
Epoch 3/5, Loss: 41.3490
Epoch 4/5, Loss: 41.5357
Epoch 5/5, Loss: 34.4408
Color Explanation: Complementary match: [179 183 152] vs [119  26  57]


In [None]:
import torch
from google.colab import drive
# Colab only: mount Google Drive, presumably so the trained weights persist
# beyond the current Colab session.
drive.mount('/content/drive')
drive_weights_path = "/content/drive/MyDrive/model_weights.pt"
torch.save(model.state_dict(), drive_weights_path)

Mounted at /content/drive


In [None]:
# Local checkpoint; the inference cell below reloads the weights from this path.
local_ckpt_path = "compatibility_model.pth"
torch.save(model.state_dict(), local_ckpt_path)


In [None]:
import argparse
import torch

# Rebuild the compatibility head, restore the weights saved after training,
# and switch to eval mode for inference.
model = OutfitCompatibilityNet(embed_dim=512, metadata_dim=1).to(device)
saved_state = torch.load("compatibility_model.pth", map_location=device)
model.load_state_dict(saved_state)
model.eval()


# --- Prediction function ---
def get_compatibility(img1_id, img2_id, compat_model=None, extractor=None):
    """Score two catalogue items with the compatibility head and explain their colors.

    Args:
        img1_id: Value of ``df['image']`` for the first item. It fills the "top"
            slot and supplies the ``kids`` metadata.
        img2_id: Value of ``df['image']`` for the second item ("bottom" slot).
        compat_model: Pair scorer; defaults to the notebook-global ``model``.
        extractor: Embedding module; defaults to the notebook-global
            ``clip_extractor`` created in the training cell.

    Returns:
        (score, explanation): the head's output as a Python float, and the text
        returned by ``explain_color_match``.

    Raises:
        ValueError: if either id is not present in ``df``.
    """
    compat_model = model if compat_model is None else compat_model
    extractor = clip_extractor if extractor is None else extractor

    def _row_for(image_id):
        # Clear error instead of pandas' opaque "single positional indexer is out-of-bounds".
        matches = df.loc[df['image'] == image_id]
        if matches.empty:
            raise ValueError(f"Unknown image id: {image_id!r}")
        return matches.iloc[0]

    def _embed(path):
        # Close the file promptly; convert() returns a loaded copy.
        with Image.open(path) as img:
            batch = test_transform(img.convert("RGB")).unsqueeze(0).to(device)
        return extractor(batch)

    row1 = _row_for(img1_id)
    row2 = _row_for(img2_id)

    with torch.no_grad():
        emb1 = _embed(row1['image_path'])
        emb2 = _embed(row2['image_path'])

        metadata = torch.tensor([[int(row1['kids'])]], dtype=torch.float32).to(device)
        score = compat_model(emb1, emb2, metadata)

    explanation = explain_color_match(row1['image_path'], row2['image_path'])
    return score.item(), explanation


# --- Main CLI ---
if __name__ == "__main__":
    import sys

    # get_ipython() only exists inside IPython (NameError under plain `python`), and
    # inside any notebook kernel argparse would parse the kernel's own argv
    # (e.g. `-f kernel.json`). Detect the kernel via sys.modules instead.
    in_colab = "google.colab" in sys.modules
    if in_colab or "ipykernel" in sys.modules:
        if in_colab:
            print("Running in Google Colab. Using example image IDs.")
        else:
            print("Running in a notebook. Using example image IDs.")
        # Provide example image IDs from your dataframe
        img1_id = df.iloc[0]['image']
        img2_id = df.iloc[1]['image']
    else:
        parser = argparse.ArgumentParser(description="Outfit Compatibility Inference")
        parser.add_argument("--img1", type=str, required=True, help="Image ID of first item")
        parser.add_argument("--img2", type=str, required=True, help="Image ID of second item")
        args = parser.parse_args()
        # Was `args.arg2`, which raised AttributeError: the option is --img2.
        img1_id, img2_id = args.img1, args.img2

    score, explanation = get_compatibility(img1_id, img2_id)
    print(f"\nðŸ‘— Compatibility between {img1_id} and {img2_id}: {score:.2f}")
    print("ðŸŽ¨ Explanation:", explanation)

Running in Google Colab. Using example image IDs.

ðŸ‘— Compatibility between 4285fab0-751a-4b74-8e9b-43af05deee22 and ea7b6656-3f84-4eb3-9099-23e623fc1018: 0.52
ðŸŽ¨ Explanation: Complementary match: [179 183 152] vs [119  26  57]


In [None]:
%%writefile app.py
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.neighbors import NearestNeighbors
import torch
import clip

# Paths
CSV_PATH = 'images.csv'
IMG_DIR = 'images'
import numpy as np

EMB_PATH = "clip_embs.npy"
ID_PATH = "img_ids.npy"

# Row i of `embs` is the CLIP embedding of image id `img_ids[i]`. The notebook
# writes both files together and skips unreadable images, so these rows are
# NOT aligned with the rows of `df`; look embeddings up through `img_ids`.
embs = np.load(EMB_PATH)
img_ids = np.load(ID_PATH)
if len(embs) != len(img_ids):
    raise ValueError(
        f"{EMB_PATH} has {len(embs)} rows but {ID_PATH} has {len(img_ids)} ids; "
        "regenerate both files together."
    )

# Load dataset metadata (embeddings were loaded above; no second load needed)
df = pd.read_csv(CSV_PATH)
df['image_path'] = df['image'].apply(lambda x: f'{IMG_DIR}/{x}.jpg')

# Setup CLIP
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load('ViT-B/32', device=device)
model.eval()

# Build Nearest Neighbors (cosine distance over the stored CLIP embeddings)
nn = NearestNeighbors(n_neighbors=100, metric='cosine').fit(embs)

import os

def recommend_full_outfit_ui(base_img_path, topk=5):
    """Return up to ``topk`` catalogue items whose CLIP embeddings are nearest the upload.

    The upload is matched to the catalogue by file name (``<image id>.<ext>``). All
    lookups go through ``img_ids`` because rows of ``embs`` follow ``img_ids``, not
    ``df``: images that failed to load were skipped when the embeddings were built.
    If the id is not indexed, the uploaded image is embedded with the loaded CLIP
    model, so arbitrary uploads still get recommendations.

    Args:
        base_img_path: Local path of the uploaded image (Gradio ``type='filepath'``),
            or None when the input is cleared.
        topk: Maximum number of recommendations.

    Returns:
        list of ``(image_path, caption)`` tuples for ``gr.Gallery``; empty when
        there is no upload.
    """
    print("Base image received:", base_img_path)
    if not base_img_path:
        return []

    base_id = os.path.splitext(os.path.basename(base_img_path))[0]
    hits = np.flatnonzero(img_ids == base_id)
    if hits.size:
        base_row = int(hits[0])
        basevec = embs[base_row].reshape(1, -1)
    else:
        print(f"Uploaded image {base_id} not found in index; embedding it with CLIP.")
        base_row = None
        with Image.open(base_img_path) as img:
            image = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            basevec = model.encode_image(image).float().cpu().numpy()

    # One extra neighbour because the query itself is normally the nearest hit;
    # never ask for more neighbours than indexed rows.
    n_query = min(topk + 1, len(embs))
    _, inds = nn.kneighbors(basevec, n_neighbors=n_query)

    recs = []
    for row in inds[0]:
        if row == base_row:
            continue
        recs.append((f'{IMG_DIR}/{img_ids[row]}.jpg', ""))  # tuple with (image_path, caption)

    return recs[:topk]


# Gradio interface: one uploaded image in, a gallery of recommended items out.
upload_box = gr.Image(type='filepath', label='Upload base item')
results_gallery = gr.Gallery(label='Recommended Outfit Items')
demo = gr.Interface(
    fn=recommend_full_outfit_ui,
    inputs=upload_box,
    outputs=results_gallery,
    title='Smart Outfit Recommender',
    description='Upload a clothing item to get full outfit recommendations.',
)

# share=True additionally requests a temporary public *.gradio.live URL.
demo.launch(share=True)

Overwriting app.py


In [None]:
!python app.py

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://2ab6ac24e0bf1789be.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
Base image received: /tmp/gradio/9a3134816dd28397a346c644ae1ded989e62c727f5dcdbcbc3a2d5b4ee65442d/0b55a8e8-0087-4c19-8729-b872718ff5ae.jpg
Base image received: /tmp/gradio/9a3134816dd28397a346c644ae1ded989e62c727f5dcdbcbc3a2d5b4ee65442d/0b55a8e8-0087-4c19-8729-b872718ff5ae.jpg
Base image received: /tmp/gradio/6bde9614f79f4a73eac35dcf3f9fb6a47711717ea415da92ec587fc92a899f4d/0b7f4987-34e4-4c85-9f28-35e04ae78ece.jpg
Base image received: /tmp/gradio/6bde9614f79f4a73eac35dcf3f9fb6a47711717ea415da92ec587fc92a899f4d/0b7f4987-34e4-4c85-9f28-35e04ae78ece.jpg
Base image received: /tmp/gradio/fc5b120cc50cd70eb5dc0a4136b06db427a4c18a6dbb8c3500da5c36d6738a1e/0a2668d3-e42a-4f46-bb7f-01