In [None]:
import random
import numpy as np
import os
import random
from PIL import Image
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt

In [None]:
def _list_style_dirs(root, gender):
    """Return sorted '<gender>/<style>' sub-paths for every style directory under root/gender.

    Sorting makes the class order (and hence the embedding database and the
    pickled outputs) reproducible, since os.listdir order is arbitrary. Entries
    that are not directories (e.g. a stray .DS_Store) are skipped because the
    database loop later calls os.listdir on each style path.
    """
    gender_dir = os.path.join(root, gender)
    return [
        os.path.join(gender, name)
        for name in sorted(os.listdir(gender_dir))
        if os.path.isdir(os.path.join(gender_dir, name))
    ]


# Dataset layout: img/<MEN|WOMEN>/<style>/<item_id>/<image files>.
root_dir = "img"
men_folder = _list_style_dirs(root_dir, "MEN")
women_folder = _list_style_dirs(root_dir, "WOMEN")
# Each entry is a style path relative to root_dir, e.g. os.path.join("MEN", <style>).
classes = men_folder + women_folder

In [None]:
class CNNBaseNetwork(nn.Module):
    """Convolutional embedding network used for garment retrieval.

    Three conv(3x3, padding=1) -> ReLU -> 2x2 max-pool stages are followed by two
    fully connected layers. Together they map an RGB image to a 128-d embedding
    that is L2-normalised per row.

    Args:
        input_size: Side length in pixels of the square input images. The
            default of 256 reproduces the original hard-coded layer sizes
            (fc1 takes 64 * 32 * 32 features).

    Attribute names (conv1, fc1, ...) form the state_dict keys. Do not rename
    them, or existing checkpoints will no longer load.
    """

    def __init__(self, input_size: int = 256):
        super(CNNBaseNetwork, self).__init__()
        if input_size < 8:
            # Three 2x2 pools need at least 8 pixels to leave a 1x1 feature map.
            raise ValueError(f"input_size must be at least 8, got {input_size}")
        self.input_size = input_size
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        # Padded 3x3 convs keep the spatial size. Each 2x2 max-pool floors it by
        # half, so the final feature map is (input_size // 8) x (input_size // 8).
        pooled = input_size // 8
        self.fc1 = nn.Linear(64 * pooled * pooled, 256)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Float tensor of shape (batch, 3, input_size, input_size).

        Returns:
            Tensor of shape (batch, 128). Each row has unit L2 norm; a row that
            is all zeros after the final ReLU stays zero.
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.pool3(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * pooled * pooled)
        x = self.fc1(x)
        x = self.relu4(x)
        x = self.fc2(x)
        x = self.relu5(x)
        # Unit-norm embeddings make Euclidean k-NN ranking equivalent to cosine ranking.
        x = F.normalize(x, p=2, dim=1)
        return x


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNBaseNetwork().to(device)
# The checkpoint is a dict whose "base_net" entry holds this network's state_dict.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state = torch.load("trained_models/cloth_retrieval_10.pth", map_location=device)
model.load_state_dict(state["base_net"])
# This notebook only runs inference, so switch to evaluation mode. CNNBaseNetwork
# currently has no dropout or batch-norm, so outputs are unchanged; this keeps the
# embeddings correct if such layers are added later.
model.eval()

In [None]:
# Embed every image under img/<MEN|WOMEN>/<style>/<item_id>/ with the trained network.
# labels[k] is the file path of the image whose embedding is database[k], so the
# two lists must stay index-aligned.
database = []
labels = []
to_tensor = transforms.ToTensor()  # stateless; build once rather than once per image
with torch.no_grad():  # inference only: do not build an autograd graph per image
    for style in classes:
        style_dir = os.path.join(root_dir, style)
        # sorted() gives a reproducible database order (os.listdir order is arbitrary).
        for imgid in sorted(os.listdir(style_dir)):
            item_dir = os.path.join(style_dir, imgid)
            if not os.path.isdir(item_dir):  # skip stray files such as .DS_Store
                continue
            for i in sorted(os.listdir(item_dir)):
                if i.startswith("."):  # skip hidden files
                    continue
                img_path = os.path.join(item_dir, i)
                img = Image.open(img_path).convert('RGB')
                # NOTE(review): ToTensor() below already rescales 8-bit pixels to [0, 1].
                # On an 8-bit image, Image.point builds an integer lookup table, so this
                # lambda quantises every pixel to 0 or 1 (rounded or truncated, depending
                # on the Pillow version). It is kept unchanged because the query
                # embedding applies the same step and the checkpoint was presumably
                # trained with it. Confirm against the training code before removing.
                img = img.point(lambda p: p / 255.0)
                img = img.resize((256, 256))
                image_tensor = to_tensor(img).to(device)
                emb = model(torch.unsqueeze(image_tensor, 0)).cpu()
                database.append(torch.squeeze(emb, 0))
                labels.append(img_path)
        print(style)


In [None]:
def test_image_embedding(path, model, device):
    """Embed a single query image with the retrieval network.

    Uses the same preprocessing as the database-building cell (RGB convert,
    point, 256x256 resize, ToTensor), so query and catalogue embeddings are
    comparable.

    Args:
        path: Path to the image file.
        model: Embedding network, called on a batch containing one image.
        device: torch device the model lives on; the input tensor is moved there.

    Returns:
        CPU tensor of shape (1, D) holding the embedding (D = 128 for
        CNNBaseNetwork), with no autograd history.
    """
    img = Image.open(path).convert('RGB')
    # NOTE(review): ToTensor() already rescales 8-bit pixels to [0, 1]. On an 8-bit
    # image, Image.point builds an integer lookup table, so this lambda quantises
    # pixels to 0 or 1. It is kept so query preprocessing matches the database
    # embeddings; confirm against the training code before changing.
    img = img.point(lambda p: p / 255.0)
    img = img.resize((256, 256))
    image_tensor = transforms.ToTensor()(img).to(device)
    # Inference only: no_grad avoids building an autograd graph that would just be detached.
    with torch.no_grad():
        img_embed = model(torch.unsqueeze(image_tensor, 0)).cpu()
    return img_embed
    

In [None]:
# Turn each per-image embedding tensor into a 1-D NumPy vector, then stack the
# vectors row-wise. This is the matrix scikit-learn will index: one row per image,
# in the same order as `labels`.
flattened_database = [vec.numpy() for vec in map(torch.flatten, database)]
X = np.array(flattened_database)

In [None]:
# Persist the embedding matrix and its index-aligned image paths, so the k-NN
# index can be rebuilt later without re-running the network over the catalogue.
for pkl_path, payload in (("embeddings.pkl", X), ("labels.pkl", labels)):
    with open(pkl_path, "wb") as f:
        pickle.dump(payload, f)

In [None]:
from sklearn.neighbors import NearestNeighbors

# Number of closest catalogue images to return per query. It is capped at the
# database size because NearestNeighbors.kneighbors raises ValueError when asked
# for more neighbours than there are fitted samples.
n_neighbors = min(5, len(X))
# The embeddings come out of F.normalize (unit L2 norm), so Euclidean distance
# ranks neighbours the same way cosine similarity would.
knn = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean')
knn.fit(X)

    

In [None]:
# Retrieve the garments in the database that are most similar to a test input.
test_image_path = "test_image.jpg"
query_embedding = test_image_embedding(test_image_path, model, device)
# kneighbors expects a 2-D array-like (n_queries, n_features). The embedding is a
# detached (1, D) CPU tensor, so hand scikit-learn its NumPy view directly.
query_vector = query_embedding.numpy()
distances, indices = knn.kneighbors(query_vector)


In [None]:
# Display the query image so the retrieved neighbours can be compared against it.
test_img = Image.open(test_image_path).convert('RGB')
fig, ax = plt.subplots()
ax.imshow(test_img)
ax.set_title(f"Query: {test_image_path}")
ax.axis("off")  # pixel-coordinate ticks carry no information for a photo
plt.show()

In [None]:
# Show the retrieved catalogue images side by side, nearest first, each titled
# with its distance to the query. The loop is driven by what kneighbors actually
# returned, not by the global n_neighbors. That global may have been edited since
# knn was fitted, which would otherwise cause an IndexError or drop results.
neighbour_ids = indices[0]
neighbour_dists = distances[0]
fig, axes = plt.subplots(
    1, len(neighbour_ids), figsize=(4 * len(neighbour_ids), 4), squeeze=False
)
for ax, idx, dist in zip(axes[0], neighbour_ids, neighbour_dists):
    similar_shop_image_path = labels[idx]
    img = Image.open(similar_shop_image_path).convert('RGB')
    ax.imshow(img)
    ax.set_title(f"distance = {dist:.3f}")
    ax.axis("off")
plt.show()
    
    

In [None]:
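# To retrieve only the single most similar image, set n_neighbors = 1 in the k-NN cell and re-run it (and the cells after it). Alternatively, override it for one query without refitting: knn.kneighbors(query_embedding, n_neighbors=1).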