In [None]:
from PIL import Image
import torchvision.transforms as transforms
import torch
import random
import numpy as np
import os
import random
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms


In [None]:
def create_pairs(root_dir="img", max_positives=17):
    """Build labelled image-path pairs for contrastive training.

    The expected directory layout is
    ``<root_dir>/{MEN,WOMEN}/<style>/<item>/<image files>``. Every item is
    represented by a single image: the first file name of its folder in
    sorted order (sorting keeps the choice stable across runs and platforms).

    For every item (the "anchor") of every style:
      * up to ``max_positives`` *other* items of the same style are sampled
        without replacement and paired with the anchor under label 1;
      * one randomly chosen item from each other style is paired with the
        anchor under label 0.

    Entries that are not directories and item folders without any file are
    skipped instead of aborting the whole walk. The name of each style is
    printed once all of its anchors have been processed.

    Args:
        root_dir: Dataset root containing the ``MEN`` and ``WOMEN`` folders.
        max_positives: Upper bound on positive pairs generated per anchor.

    Returns:
        A tuple ``(pairs, labels)`` of NumPy arrays: ``pairs`` holds
        ``[anchor_path, other_path]`` rows and ``labels`` the matching 1/0
        values. With no pairs at all, both arrays are empty.
    """
    # Style folders are addressed relative to root_dir, e.g. "MEN/<style>".
    classes = [
        os.path.join(gender, style_name)
        for gender in ("MEN", "WOMEN")
        for style_name in os.listdir(os.path.join(root_dir, gender))
        if os.path.isdir(os.path.join(root_dir, gender, style_name))
    ]

    # Resolve each item's representative image once up front, so the pairing
    # loops below never have to touch the file system again.
    images_by_style = {}
    for style in classes:
        style_folder = os.path.join(root_dir, style)
        images = []
        for item in sorted(os.listdir(style_folder)):
            item_folder = os.path.join(style_folder, item)
            if not os.path.isdir(item_folder):
                continue
            files = sorted(os.listdir(item_folder))
            if files:
                images.append(os.path.join(item_folder, files[0]))
        images_by_style[style] = images

    pairs = []
    labels = []
    for style, images in images_by_style.items():
        n_others = len(images) - 1
        # Every item acts as an anchor, including the last one of the style.
        for anchor_idx, img1 in enumerate(images):
            # Positives: sample positions among the *other* items only, so an
            # anchor is never paired with itself.
            for pos in random.sample(range(n_others), min(max_positives, n_others)):
                # Positions at or past the anchor shift by one to skip its slot.
                img2 = images[pos if pos < anchor_idx else pos + 1]
                pairs.append([img1, img2])
                labels.append(1)

            # Negatives: one random item from each of the other styles.
            for other_style, other_images in images_by_style.items():
                if other_style == style or not other_images:
                    continue
                pairs.append([img1, random.choice(other_images)])
                labels.append(0)

        print(style)
    return np.array(pairs), np.array(labels)

In [None]:
# Build the training pairs (image-path pairs) and their 1/0 labels.
# create_pairs() walks the image directory tree and prints each style as it
# is completed.
tr_pairs, tr_y = create_pairs()

In [None]:
# Number of image pairs per mini-batch handed to the DataLoader.
batch_size = 256

In [None]:

# Image pipeline: resize to a fixed 256x256, then convert to a tensor.
preprocess = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

class CustomDataset(torch.utils.data.Dataset):
    """Dataset of labelled image pairs that are loaded lazily from disk.

    Args:
        pair_paths: Indexable collection whose entries unpack into two image
            paths ``(img1_path, img2_path)``.
        labels: Indexable collection of per-pair labels, aligned with
            ``pair_paths``.
        transform: Callable applied to each RGB ``PIL.Image``. It must return
            the final network input: a pipeline ending in
            ``transforms.ToTensor()`` already scales pixel values to
            ``[0, 1]``, so no further rescaling is applied here.
    """

    def __init__(self, pair_paths,labels, transform):
        self.pair_paths = pair_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of pairs."""
        return len(self.pair_paths)

    def _load(self, path):
        """Open ``path``, force RGB and run the transform on it."""
        # The context manager closes the underlying file handle right away;
        # convert() returns a fully loaded image, so nothing is lost.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        """Return ``(img1, img2, label)`` for the pair at ``index``."""
        label = self.labels[index]
        img1_path, img2_path = self.pair_paths[index]
        # Return the transform output as-is. Dividing by 255 again would
        # squash ToTensor() output from [0, 1] down to [0, 1/255].
        return self._load(img1_path), self._load(img2_path), label

# Wrap the generated pairs in the dataset and serve them as shuffled
# mini-batches of `batch_size` pairs.
custom_dataset = CustomDataset(pair_paths=tr_pairs, labels=tr_y, transform=preprocess)
data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)


In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss computed from precomputed pair distances.

    With ``target`` equal to 1 a pair is penalised by its squared distance;
    with ``target`` equal to 0 it is penalised by the squared amount by which
    the distance falls short of ``margin`` (zero once the distance reaches
    the margin). The result is the mean over all elements.
    """

    def __init__(self, margin=1):
        super().__init__()
        self.margin = margin

    def forward(self, euclidean_distance, target):
        # Term for target == 1: squared distance.
        pull_term = target * euclidean_distance.pow(2)
        # Term for target == 0: squared shortfall from the margin, floored at 0.
        shortfall = torch.clamp(self.margin - euclidean_distance, min=0.0)
        push_term = (1 - target) * shortfall.pow(2)
        return (pull_term + push_term).mean()

In [None]:
class CNNBaseNetwork(nn.Module):
    """Convolutional encoder producing L2-normalised 128-dimensional embeddings.

    Architecture: three stages of 3x3 convolution (padding 1) -> ReLU ->
    2x2 max-pooling, followed by two fully connected layers, each with a ReLU,
    and a final L2 normalisation along the feature dimension.

    ``fc1`` is sized for ``64 * 32 * 32`` input features, which corresponds to
    256x256 inputs after the three 2x poolings.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Stage 3: 64 -> 64 channels.
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)
        # Fully connected head: flattened feature map -> 256 -> 128.
        self.fc1 = nn.Linear(64 * 32 * 32, 256)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        features = self.pool1(self.relu1(self.conv1(x)))
        features = self.pool2(self.relu2(self.conv2(features)))
        features = self.pool3(self.relu3(self.conv3(features)))
        # Collapse channels and spatial dimensions into one vector per sample.
        flat = features.view(len(features), -1)
        hidden = self.relu4(self.fc1(flat))
        embedding = self.relu5(self.fc2(hidden))
        return F.normalize(embedding, p=2, dim=1)

class SiameseNetwork(nn.Module):
    """Twin network: encodes both inputs with one shared base network and
    returns the Euclidean distance between the two embeddings.

    Args:
        base_network: Module used to embed each input; the same instance
            (and therefore the same weights) is applied to both branches.
    """

    def __init__(self, base_network):
        super().__init__()
        self.base_network = base_network

    def forward(self, input_a, input_b):
        # Run both branches through the shared encoder, first a, then b.
        embedding_a, embedding_b = (
            self.base_network(batch) for batch in (input_a, input_b)
        )
        # keepdim=True keeps the reduced dimension, giving one column of
        # distances (shape (batch, 1) for 2-D embeddings).
        return (embedding_a - embedding_b).norm(dim=1, keepdim=True)

# Number of passes over the training pairs.
epochs = 20

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared encoder and the Siamese wrapper around it. Module.to() returns the
# module itself, so the wrapper can be moved to the device in the same
# statement.
base_network = CNNBaseNetwork()
siamese_network = SiameseNetwork(base_network).to(device)

# Contrastive loss with its default margin; RMSprop with its default
# hyperparameters over all network parameters.
criterion = ContrastiveLoss()
optimizer = optim.RMSprop(siamese_network.parameters())



In [None]:
# Checkpoint payload: the base network's weights and the optimizer's state.
# NOTE(review): optimizer.state_dict() is built when this cell runs, so
# anything saved from this dict carries the optimizer state as of this cell
# (before the training loop below) - confirm that is what is intended.
state={"base_net":base_network.state_dict(),"opt":optimizer.state_dict()}

In [None]:
# Train the Siamese network with the contrastive loss, writing one checkpoint
# file per epoch and reporting the mean batch loss.
for epoch in range(epochs):
    running_loss = 0.0
    for img1, img2, labels in tqdm(data_loader, desc=f'Epoch {epoch + 1}/{epochs}'):
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = siamese_network(img1, img2)
        # The network returns one distance per pair with shape (batch, 1);
        # give the labels the same shape and a float dtype for the loss.
        loss = criterion(outputs, labels.unsqueeze(1).float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Build the checkpoint at save time. A state dict captured before training
    # would store the pre-training optimizer state in every file.
    torch.save(
        {"base_net": base_network.state_dict(), "opt": optimizer.state_dict()},
        f"cloth_retrieval_{epoch}.pth",
    )
    print(f'Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(data_loader)}')

print('Training finished')
