<a href="https://colab.research.google.com/github/sambhavpurohit14/Smart_Gallery/blob/text-encoder-trials/text_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import functools

import gensim.downloader as api
import numpy as np
import torch
from transformers import BertModel, BertTokenizer

In [None]:
# Pretrained model identifiers used by `get_caption_embeddings`.
CBOW_MODEL_NAME = 'word2vec-google-news-300'
BERT_MODEL_NAME = "bert-base-uncased"


@functools.lru_cache(maxsize=None)
def _load_cbow_model():
    """Load the pretrained word2vec vectors once and reuse them for every caption."""
    # api.load is very expensive; the original reloaded the model on every call.
    return api.load(CBOW_MODEL_NAME)


@functools.lru_cache(maxsize=None)
def _load_bert(model_name=BERT_MODEL_NAME):
    """Load the BERT (tokenizer, model) pair once per `model_name` and reuse it."""
    tokenizer = BertTokenizer.from_pretrained(model_name)
    bert_model = BertModel.from_pretrained(model_name)
    bert_model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, bert_model


def _cbow_embedding(caption, kv_model):
    """Mean word vector of the whitespace-separated tokens of `caption` found in `kv_model`.

    Tokens missing from the vocabulary are skipped. If none are found, an all-zero
    vector of length `kv_model.vector_size` is returned.
    """
    vectors = [kv_model[word] for word in caption.split() if word in kv_model]
    if not vectors:
        return np.zeros(kv_model.vector_size)
    return np.mean(vectors, axis=0)


def _masked_mean_pool(token_embeddings, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padding positions.

    token_embeddings: [batch_size, sequence_length, hidden_size] (BERT last_hidden_state).
    attention_mask:   [batch_size, sequence_length], 1 for real tokens and 0 for padding.
    Returns a [batch_size, hidden_size] tensor.
    """
    # unsqueeze -> [batch_size, sequence_length, 1] so the mask broadcasts over the
    # hidden axis. It returns a view of the same data; nothing new is added or copied.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # Clamp guards against division by zero for a row with no real tokens.
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


def _bert_embedding(caption, tokenizer, bert_model):
    """Attention-mask-weighted mean of BERT's last hidden layer for `caption`, shape (1, hidden_size)."""
    # The tokenizer returns PyTorch tensors of token ids plus an attention mask
    # for a batch that contains just this caption.
    inputs = tokenizer([caption], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        token_embeddings = bert_model(**inputs).last_hidden_state
    # Special tokens such as [CLS]/[SEP] have mask 1, so they are part of the average.
    sentence_embeddings = _masked_mean_pool(token_embeddings, inputs["attention_mask"])
    return sentence_embeddings.numpy()


def get_caption_embeddings(caption, model_type):
    """
    Generate a single embedding vector for a caption using the specified model.

    Parameters
    ----------
    caption : str
        Free-text caption to embed.
    model_type : str
        'cbow' -> mean of the word2vec-google-news-300 vectors of the caption's
                  whitespace-separated tokens (out-of-vocabulary tokens are skipped;
                  all-zero vector if none are known).
        'bert' -> attention-mask-weighted mean of bert-base-uncased last-layer
                  token embeddings.
        Matching ignores case and surrounding whitespace.

    Returns
    -------
    numpy.ndarray
        'cbow': shape (vector_size,); 'bert': shape (1, hidden_size).

    Raises
    ------
    ValueError
        If `model_type` is neither 'cbow' nor 'bert'.
    """
    encoder = str(model_type).strip().lower()
    if encoder == 'cbow':
        return _cbow_embedding(caption, _load_cbow_model())
    if encoder == 'bert':
        tokenizer, bert_model = _load_bert()
        return _bert_embedding(caption, tokenizer, bert_model)
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'cbow' or 'bert'.")


In [None]:
# Interactive demo: prompt for an encoder and a sentence, then print its embedding.
# The encoder choice is normalized so "BERT" or "cbow " still match instead of
# silently producing None.
# NOTE: input() blocks "Restart & Run All"; the fixed-caption cell below is the
# reproducible example.
encoder_type = input("Enter the encoder type (cbow/bert): ").strip().lower()
sentence = input("Enter the sentence: ")
print(get_caption_embeddings(sentence, encoder_type))

In [None]:
# Fixed example caption, so this cell reproduces the same output without user input.
example_caption = 'he discovered a map that will lead him to the treasure island'
example_embedding = get_caption_embeddings(example_caption, 'cbow')
print(example_embedding)

[ 0.0793457   0.07609864  0.07420655  0.04150391 -0.02399292 -0.0560791
 -0.04064331 -0.13105468  0.06124573  0.09473877  0.04650269 -0.13260803
  0.01254272 -0.02023926 -0.1161438   0.04530334 -0.0320221   0.14226075
  0.02257233  0.01842651  0.02873173  0.00661621  0.01523438  0.05429916
  0.04060059 -0.00440369 -0.07633056  0.01339111  0.04108734 -0.01483154
  0.01257324  0.01239624 -0.05804443  0.01967621  0.02949219 -0.06587219
  0.05095215 -0.02755127  0.07716064  0.0607544   0.06325684 -0.00527649
  0.11011963 -0.03266602 -0.02029419 -0.06052704 -0.06203613  0.01164856
  0.02537842  0.00384521  0.04329834  0.07330932  0.02456665 -0.08997802
 -0.0588501   0.03557129 -0.0566864  -0.16756591  0.05548554 -0.01115417
  0.08276768  0.09409697 -0.02213135 -0.01397095 -0.0727417  -0.02767944
 -0.00895996  0.02041016 -0.01472473  0.02885437  0.05957031 -0.0090477
  0.0791565  -0.04113159 -0.13609314 -0.04261475  0.07692871  0.07792969
  0.06124268  0.03267212  0.05451507 -0.07513428  0.0

BERT — sample sentence-embedding output (truncated excerpt)

-2.97022969e-01 -3.45540017e-01  1.59468085e-01  1.29468590e-01
   4.95176822e-01 -1.65851742e-01  1.96444586e-01  3.67893666e-01
   6.09932616e-02 -2.23868951e-01 -6.20057061e-02 -7.82726109e-01
   1.73034549e-01  7.72786558e-01 -2.23108456e-01 -1.04516871e-01


-------------
CBOW — sample output (truncated excerpt; same leading values as the cbow output printed above)

0.0793457   0.07609864  0.07420655  0.04150391 -0.02399292 -0.0560791
 -0.04064331 -0.13105468  0.06124573  0.09473877  0.04650269 -0.13260803
  0.01254272 -0.02023926 -0.1161438   0.04530334 -0.0320221   0.14226075
  