<a href="https://colab.research.google.com/github/vmunagal/FashionCLIPModel/blob/main/FashionCLIPModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd 
import pickle 
import os 
from PIL import Image

import glob

from IPython.display import display
from IPython.display import Image as IPImage
import warnings 
warnings.filterwarnings('ignore')


In [None]:
# Load the preprocessed apparel table from Google Drive. Despite its name,
# `output_file` holds the unpickled object (the cells below use it like a
# pandas DataFrame), not a file handle.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
with open("/content/drive/MyDrive/pickels/16k_apperal_data_preprocessed","rb") as input_file:
    
    output_file =pickle.load(input_file)

In [None]:
output_file['product_type_name'].unique()

In [None]:
output_file.columns

In [None]:
output_file.head(5)

In [None]:
output_file = output_file.reset_index(drop=True)

In [None]:
def image_fn(image):
    """Return the image file name for an id by appending ``.jpeg`` to it."""
    extension = '.jpeg'
    return image + extension

In [None]:
output_file['Image_name']=output_file.asin.apply(image_fn)

In [None]:
! mkdir images 

In [None]:
! cp -r /content/drive/MyDrive/16k_images  images 

In [None]:
display(IPImage(('/content/images/16k_images/0000000060.jpeg'), width=200))

In [None]:
! pip install transformers
! pip install timm
! pip install sentence_transformers

In [None]:
import cv2
import torch 
import numpy as np 
import albumentations as A
import timm
import torch.nn as nn
from torch.optim import AdamW 
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import torchvision.models as models
from torch.optim.lr_scheduler import ReduceLROnPlateau 
from transformers import get_linear_schedule_with_warmup
from torch.nn.functional import normalize
from torch.nn.functional import softmax
from sentence_transformers import util

In [None]:
from transformers import AutoTokenizer , AutoModel

class config:
    """Namespace of paths and hyper-parameters shared by the notebook cells.

    The tokenizer is downloaded/loaded once, when this class body executes.
    ``sentence_max_lenght`` keeps its original spelling because other cells
    read it under that name.
    """

    # Pretrained backbones and the tokenizer matching the text backbone.
    text_model = 'distilbert-base-uncased'
    image_model = 'resnet34'
    tokenizer = AutoTokenizer.from_pretrained(text_model)

    # Data locations.
    image_path = '/content/images/16k_images'
    text_path = '/content/drive/MyDrive/pickels/'

    # Training schedule.
    train_batch_size = 32
    eval_batch_size = 16
    epochs = 10

    # Preprocessing: images are resized to image_size x image_size pixels and
    # titles are padded/truncated to sentence_max_lenght tokens.
    image_size = 250
    sentence_max_lenght = 30

    # Width of the shared embedding space produced by the projection heads.
    dimension = 256
    image_output_dim = 256
    text_output_dim = 256

    # Feature widths coming out of the text / image backbones
    # (presumably DistilBERT hidden size and ResNet-34 pooled features).
    output_dim_text = 768
    outout_dim_image = 512



In [None]:
class FashionDataset:
    """Map-style dataset pairing a product image with its tokenized title.

    Args:
        image: indexable collection of image file names, relative to
            ``config.image_path``.
        text: indexable collection of title strings aligned with ``image``.

    Each item is a dict with:
        ``image``: float tensor in channel-first layout, resized to
            ``config.image_size`` and normalized,
        ``ids``: long tensor of token ids padded/truncated to
            ``config.sentence_max_lenght``,
        ``mask``: long tensor attention mask of the same length.
    """

    def __init__(self, image, text):
        self.image = image
        self.text = text

        # Deterministic preprocessing only: resize, then scale/standardize
        # pixel values with the library's default mean/std.
        self.image_aug = A.Compose(
            [
                A.Resize(config.image_size, config.image_size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position ``idx``.

        Raises:
            FileNotFoundError: if the image file is missing or unreadable.
        """
        encoded = config.tokenizer(
            self.text[idx],
            max_length=config.sentence_max_lenght,
            truncation=True,
            padding='max_length',
        )

        image_file = f"{config.image_path}/{self.image[idx]}"
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f"Could not read image file: {image_file}")

        # OpenCV loads images as BGR; convert to RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_aug(image=image)['image']

        return {
            # HWC -> CHW, the layout expected by the image backbone.
            'image': torch.tensor(image).permute(2, 0, 1).float(),
            'ids': torch.tensor(encoded['input_ids'], dtype=torch.long),
            'mask': torch.tensor(encoded['attention_mask'], dtype=torch.long),
        }
        



In [None]:
class FashionEncodeImageModel(nn.Module):
    """Image encoder: a pretrained timm backbone returning pooled features."""

    def __init__(self):
        super().__init__()

        # The attribute is called "restnet50model" although the architecture
        # is whatever config.image_model names; the name is part of the
        # state_dict keys, so it is kept as-is.
        self.restnet50model = timm.create_model(
            config.image_model,
            pretrained=True,
            num_classes=0,
            global_pool="avg",
        )

    def forward(self, image_data):
        """Return the backbone features for a batch of images."""
        return self.restnet50model(image_data)





In [None]:
class FashionTextModel(nn.Module):
    """Text encoder: a pretrained transformer summarized by its first token."""

    def __init__(self):
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(config.text_model)

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at sequence position 0 for each input."""
        encoded = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Keep only the first token's vector as the sentence representation.
        return encoded.last_hidden_state[:, 0, :]

In [None]:
class FashionClipModel(nn.Module):
    """CLIP-style dual encoder scoring every image against every text in a batch.

    Backbone features from the image and text encoders are projected into a
    shared embedding space; the forward pass returns temperature-scaled
    cosine-similarity logits between all image/text pairs of the batch.
    """

    def __init__(self):
        super().__init__()

        # Projection head for text features -> shared embedding space.
        self.text_embedding = nn.Sequential(
            nn.Linear(config.output_dim_text, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, config.text_output_dim),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.LayerNorm(config.text_output_dim),
        )

        # Projection head for image features -> shared embedding space.
        # (The misspelled attribute name is part of the state_dict keys and is
        # referenced by other cells, so it is kept.)
        self.image_embeding = nn.Sequential(
            nn.Linear(config.outout_dim_image, config.image_output_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.LayerNorm(config.image_output_dim),
        )

        self.text_model = FashionTextModel()
        self.image_model = FashionEncodeImageModel()

        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, image, ids, attention_mask):
        """Return ``(logits_per_image, logits_per_text)`` for a batch.

        Args:
            image: batch of images for the image encoder.
            ids: token ids for the text encoder.
            attention_mask: attention mask matching ``ids``.

        Returns:
            ``logits_per_image`` with one row per image and one column per
            text, and ``logits_per_text``, its transpose.
        """
        out_image = self.image_model(image)
        text_output = self.text_model(ids, attention_mask)

        text_output = self.text_embedding(text_output)
        out_image = self.image_embeding(out_image)

        # L2-normalise both sides so the dot product below really is a cosine
        # similarity; without this the logits scale with the embedding norms.
        text_output = nn.functional.normalize(text_output, p=2, dim=-1)
        out_image = nn.functional.normalize(out_image, p=2, dim=-1)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * out_image @ text_output.T
        logits_per_text = logits_per_image.T

        return logits_per_image, logits_per_text

       






In [None]:
model=FashionClipModel()

In [None]:
DEVICE=('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
model=model.to(DEVICE)

In [None]:
print(f'# of trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')
print(f'# of non-trainable params: {sum(p.numel() for p in model.parameters() if not p.requires_grad):,}')

In [None]:
output_file.columns

In [None]:
# Keep only the columns needed for training. Take an explicit copy so the
# in-place row drops performed on fashion_dataset operate on an independent
# frame rather than on a slice of output_file.
fashion_dataset = output_file[['title', 'Image_name']].copy()

In [None]:
# NOTE(review): the row labelled 11195 is removed without an explanation —
# presumably a bad record or an unusable image; confirm and document why.
fashion_dataset.drop(11195,axis=0,inplace=True)

In [None]:
# Image files to exclude from training; rows whose Image_name appears here are
# dropped from fashion_dataset.
# NOTE(review): presumably these files are missing from (or unusable in) the
# 16k image folder — confirm.
image_name=[
'B07147JSY5.jpeg',
'B0714D3YBH.jpeg',
'B0714DWP9R.jpeg',
'B0714BMPJQ.jpeg',
'B071454LGB.jpeg',
'B0714DWVNG.jpeg',
'B0714B1RG2.jpeg',
'B07148TD5Z.jpeg',
'B07145W3G3.jpeg',
'B071469DM9.jpeg',
'B07145F1VR.jpeg',
'B0714C9M7S.jpeg',
'B07148V8SN.jpeg',
'B07148BVMF.jpeg',
'B07144VH39.jpeg',
'B0714DWR46.jpeg',
'B07146SR2T.jpeg',
'B07144GTFJ.jpeg',
'B0714CXPTN.jpeg'
]

In [None]:
# Drop every row whose image is on the exclusion list with one membership
# test, instead of scanning the whole frame and dropping once per file name.
fashion_dataset.drop(
    fashion_dataset.index[fashion_dataset['Image_name'].isin(image_name)],
    axis=0,
    inplace=True,
)


In [None]:
fashion_dataset=fashion_dataset.reset_index(drop=True)

In [None]:
fashion_dataset[fashion_dataset['Image_name']=='B0714CXPTN.jpeg']

In [None]:
def loss_fn(input, target):
    """Return the mean cross-entropy of raw logits ``input`` against the
    class indices in ``target``."""
    return nn.CrossEntropyLoss()(input, target)


In [None]:
class CrossEntropyLossManual:
    """Hand-written mean cross-entropy over a batch of logits.

    ``y0`` holds the logits with shape (batch_size, C); ``x`` has shape
    (batch_size,) and contains class indices from 0 to C-1. Samples whose
    target equals ``ignore_index`` are excluded from the average.
    """

    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index = ignore_index

    def __call__(self, y0, x):
        """Return the mean negative log-likelihood as a 0-dim tensor.

        Returns a zero tensor when no sample is left after applying
        ``ignore_index`` (including an empty batch), instead of dividing by
        zero.
        """
        targets = x.long()
        keep = targets != self.ignore_index
        if not bool(keep.any()):
            return y0.new_zeros(())

        # log-softmax via logsumexp: exp() of a large logit would overflow to
        # inf and turn the naive exp/sum ratio into NaN.
        log_probs = y0 - torch.logsumexp(y0, dim=1, keepdim=True)
        picked = log_probs[keep].gather(1, targets[keep].unsqueeze(1))
        return -picked.mean()


loss = CrossEntropyLossManual()

In [None]:
def train_fn(model,dataloader,optimizer,scheduler):
    """Train ``model`` for one epoch and return the mean per-batch loss.

    Args:
        model: called as ``model(image, ids, attention_mask)``; must return
            ``(logits_per_image, logits_per_text)``.
        dataloader: yields dicts with ``'image'``, ``'ids'`` and ``'mask'``
            tensors.
        optimizer: optimizer stepped once per batch.
        scheduler: not used inside this function; kept so existing callers
            keep working.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()

    train_loss = 0

    for data in tqdm(dataloader, total=len(dataloader)):
        image = data['image'].to(DEVICE)
        ids = data['ids'].to(DEVICE)
        attention_mask = data['mask'].to(DEVICE)

        optimizer.zero_grad()

        logits_per_image, logits_per_text = model(image, ids, attention_mask)

        # The i-th image matches the i-th text, so the correct "class" for
        # row i of either logits matrix is i.
        target = torch.arange(image.size(0), device=DEVICE)

        loss_image = loss_fn(logits_per_image, target)
        loss_text = loss_fn(logits_per_text, target)

        # Symmetric loss: average of both directions. The parentheses matter;
        # without them only the text term would be halved.
        total_loss = (loss_image + loss_text) / 2.0

        train_loss += total_loss.item()

        total_loss.backward()

        # Clamp each gradient element to [-2, 2] before the update.
        nn.utils.clip_grad_value_(model.parameters(), clip_value=2.0)

        optimizer.step()

    return train_loss / len(dataloader)


In [None]:
@torch.no_grad()
def eval_fn(model,dataloader):
    """Evaluate ``model`` on ``dataloader`` and return the mean per-batch loss.

    Args:
        model: called as ``model(image, ids, attention_mask)``; must return
            ``(logits_per_image, logits_per_text)``.
        dataloader: yields dicts with ``'image'``, ``'ids'`` and ``'mask'``
            tensors.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.eval()

    eval_loss = 0

    for data in tqdm(dataloader, total=len(dataloader)):
        image = data['image'].to(DEVICE)
        ids = data['ids'].to(DEVICE)
        attention_mask = data['mask'].to(DEVICE)

        logits_per_image, logits_per_text = model(image, ids, attention_mask)

        # The i-th image matches the i-th text. The target is 1-D, so no
        # transpose is needed (``.T`` on a 1-D tensor is deprecated).
        target = torch.arange(image.size(0), device=DEVICE)

        texts_loss = loss_fn(logits_per_text, target)
        images_loss = loss_fn(logits_per_image, target)
        loss = (images_loss + texts_loss) / 2.0

        eval_loss += loss.item()

    return eval_loss / len(dataloader)







In [None]:
def run():
    """Train the global ``model`` and checkpoint the best validation epoch.

    Splits ``fashion_dataset`` 67/33 with a fixed seed, trains for
    ``config.epochs`` epochs with AdamW, lowers the learning rate when the
    validation loss plateaus, and saves the weights whenever the validation
    loss improves on the best value seen so far.

    NOTE(review): weights are written to ``model1.pt`` while the inference
    cell loads ``model.pt`` from the same folder — confirm which is intended.
    """
    x_train, x_test = train_test_split(fashion_dataset, test_size=0.33, random_state=42)

    x_train = x_train.reset_index(drop=True)
    x_test = x_test.reset_index(drop=True)

    train_dataset = FashionDataset(x_train.Image_name, x_train.title)
    test_dataset = FashionDataset(x_test.Image_name, x_test.title)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        # Reshuffle every epoch so each batch sees different in-batch
        # negatives instead of the same fixed groups.
        shuffle=True,
    )

    valid_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.eval_batch_size,
    )

    optimizer = AdamW(model.parameters(), lr=2e-5)

    # Divide the learning rate by 10 after more than one epoch without
    # validation improvement.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=1,
    )

    path = '/content/drive/MyDrive/FashionModel/model1.pt'

    best_loss = float('inf')

    for epoch in range(config.epochs):

        print(f' Current epoch value was {epoch+1} out of {config.epochs}')

        loss_train = train_fn(model, train_data_loader, optimizer, scheduler)
        loss_val = eval_fn(model, valid_data_loader)

        print(f' The epoch {epoch+1} encountered a  train loss of : {loss_train} and validation loss of {loss_val}')

        scheduler.step(loss_val)

        if loss_val < best_loss:
            # Remember the new best value; otherwise every epoch would
            # overwrite the checkpoint regardless of quality.
            best_loss = loss_val

            torch.save(model.state_dict(), path)

            print("Saved Best Model!")
    




In [None]:
! cd images

In [None]:
run()

# Inference

In [None]:
DEVICE=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
model=FashionClipModel()

In [None]:
# Restore trained weights into the freshly built model; map_location remaps
# the stored tensors onto DEVICE.
# NOTE(review): run() saves its checkpoint as .../FashionModel/model1.pt, but
# this cell loads .../FashionModel/model.pt — confirm the intended file.
model.load_state_dict(torch.load('/content/drive/MyDrive/FashionModel/model.pt',map_location=DEVICE))
# Move the restored model to DEVICE.

model.to(DEVICE)


In [None]:
# Rebuild the train/validation split for inference with the same test_size
# and random_state that run() uses, then wrap both parts in data loaders.
x_train, x_test = (
    part.reset_index(drop=True)
    for part in train_test_split(fashion_dataset, test_size=0.33, random_state=42)
)

train_dataset = FashionDataset(x_train.Image_name, x_train.title)
test_dataset = FashionDataset(x_test.Image_name, x_test.title)

train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.train_batch_size
)
valid_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=config.eval_batch_size
)


In [None]:
model.eval()

In [None]:
x_test.iloc[18,:].values

In [None]:
! cd images

In [None]:
! ls -la

In [None]:
def get_image_embeddings(valid_data_loader):
    """Embed every image served by ``valid_data_loader``.

    Uses the module-level ``model`` and ``DEVICE``. Each batch goes through
    the image backbone and then the image projection head; the per-batch
    results are concatenated in loader order.
    """
    embedded_batches = []
    with torch.inference_mode():
        for batch in tqdm(valid_data_loader):
            pixels = batch['image'].to(DEVICE)
            backbone_features = model.image_model(pixels)
            embedded_batches.append(model.image_embeding(backbone_features))
    return torch.cat(embedded_batches)

In [None]:
output=get_image_embeddings(valid_data_loader)

In [None]:
images_names=list(x_test.iloc[:,1])

In [None]:
import matplotlib.pyplot as plt 

In [None]:
@torch.no_grad()
def find_match_text_image(model, image_embeddings, text, image_file, top_k=6):
    """Plot the images whose embeddings are most similar to a text query.

    Args:
        model: object exposing ``text_model`` and ``text_embedding``.
        image_embeddings: 2-D tensor of projected image embeddings, one row
            per entry of ``image_file``.
        text: query string.
        image_file: sequence of image file names (relative to
            ``config.image_path``) aligned with the rows of
            ``image_embeddings``.
        top_k: number of best matches to show (default 6). Clamped to the
            number of available images.
    """
    encoded = config.tokenizer(
        text,
        max_length=config.sentence_max_lenght,
        truncation=True,
        padding='max_length',
    )

    # Add a batch dimension of 1 for the single query.
    input_ids = torch.tensor(encoded['input_ids']).unsqueeze(0).to(DEVICE)
    attention_mask = torch.tensor(encoded['attention_mask']).unsqueeze(0).to(DEVICE)

    with torch.inference_mode():
        text_features = model.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        text_embeddings = model.text_embedding(text_features)

        # Rank by cosine similarity: normalise both sides and actually use the
        # normalised tensors. The temperature is a positive constant factor
        # and cannot change the ranking, so it is left out.
        image_embeddings_n = normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = normalize(text_embeddings, p=2, dim=-1)
        similarity = (text_embeddings_n @ image_embeddings_n.T).squeeze(0)

        # Take the k best matches directly; never ask for more results than
        # there are images.
        k = max(0, min(top_k, similarity.numel()))
        _, indices = torch.topk(similarity, k)

    matches = [image_file[idx] for idx in indices.tolist()]

    # Three images per row; with the default top_k this is a 2 x 3 grid.
    n_cols = 3
    n_rows = max(1, -(-k // n_cols))
    _, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
    for ax in axes.flatten():
        ax.axis("off")

    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{config.image_path}/{match}")
        # OpenCV loads BGR; matplotlib expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
       
       


In [None]:
display(IPImage(('/content/images/16k_images/B06ZXWCW8L.jpeg'), width=200))

In [None]:
data='bar iii sleeveless buttondown tunic deep black xs'
find_match_text_image(model,output,data,images_names)