<a href="https://colab.research.google.com/github/rahiakela/transformers-research-and-practice/blob/main/sentence-transformer-works/01_sentence_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Sentence Embeddings

**Reference**:

[Computing Sentence Embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html)

In [None]:
!pip install sentence-transformers

In [10]:
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

import torch

import pickle

##Computing Sentence Embeddings

In [None]:
model = SentenceTransformer("all-MiniLM-L6-v2")

In [5]:
# Our sentences we like to encode
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog."
]

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

# Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
  print(f"Sentence: {sentence}")
  print(f"Embedding: {embedding}")
  print("#"*20)

Sentence: This framework generates embeddings for each input sentence
Embedding: [-1.37173552e-02 -4.28515449e-02 -1.56286024e-02  1.40537303e-02
  3.95537727e-02  1.21796280e-01  2.94334106e-02 -3.17524187e-02
  3.54959629e-02 -7.93139935e-02  1.75878741e-02 -4.04369719e-02
  4.97259349e-02  2.54912246e-02 -7.18700588e-02  8.14968869e-02
  1.47069141e-03  4.79626991e-02 -4.50336412e-02 -9.92174670e-02
 -2.81769745e-02  6.45046085e-02  4.44670543e-02 -4.76217009e-02
 -3.52952331e-02  4.38671783e-02 -5.28566055e-02  4.33063833e-04
  1.01921506e-01  1.64072234e-02  3.26996595e-02 -3.45986746e-02
  1.21339476e-02  7.94870779e-02  4.58345609e-03  1.57777797e-02
 -9.68206208e-03  2.87625659e-02 -5.05805984e-02 -1.55793717e-02
 -2.87906546e-02 -9.62280575e-03  3.15556750e-02  2.27349028e-02
  8.71449187e-02 -3.85027491e-02 -8.84718448e-02 -8.75498448e-03
 -2.12343335e-02  2.08923239e-02 -9.02077407e-02 -5.25732562e-02
 -1.05638904e-02  2.88310610e-02 -1.61455162e-02  6.17837207e-03
 -1.23234

In [6]:
print("Max Sequence Length:", model.max_seq_length)

Max Sequence Length: 256


##Storing & Loading Embeddings

In [8]:
# Store sentences & embeddings on disc
with open("embeddings.pkl", "wb") as f_out:
  pickle.dump({
        "sentences": sentences,
        "embeddings": embeddings
      },
      f_out,
      protocol=pickle.HIGHEST_PROTOCOL
    )

In [9]:
# Load sentences & embeddings from disc
with open("embeddings.pkl", "rb") as f_in:
  stored_data = pickle.load(f_in)
  stored_sentences = stored_data["sentences"]
  stored_embeddings = stored_data["embeddings"]

# Print the embeddings
for sentence, embedding in zip(stored_sentences, stored_embeddings):
  print(f"Sentence: {sentence}")
  print(f"Embedding: {embedding}")
  print("#"*20)

Sentence: This framework generates embeddings for each input sentence
Embedding: [-1.37173552e-02 -4.28515449e-02 -1.56286024e-02  1.40537303e-02
  3.95537727e-02  1.21796280e-01  2.94334106e-02 -3.17524187e-02
  3.54959629e-02 -7.93139935e-02  1.75878741e-02 -4.04369719e-02
  4.97259349e-02  2.54912246e-02 -7.18700588e-02  8.14968869e-02
  1.47069141e-03  4.79626991e-02 -4.50336412e-02 -9.92174670e-02
 -2.81769745e-02  6.45046085e-02  4.44670543e-02 -4.76217009e-02
 -3.52952331e-02  4.38671783e-02 -5.28566055e-02  4.33063833e-04
  1.01921506e-01  1.64072234e-02  3.26996595e-02 -3.45986746e-02
  1.21339476e-02  7.94870779e-02  4.58345609e-03  1.57777797e-02
 -9.68206208e-03  2.87625659e-02 -5.05805984e-02 -1.55793717e-02
 -2.87906546e-02 -9.62280575e-03  3.15556750e-02  2.27349028e-02
  8.71449187e-02 -3.85027491e-02 -8.84718448e-02 -8.75498448e-03
 -2.12343335e-02  2.08923239e-02 -9.02077407e-02 -5.25732562e-02
 -1.05638904e-02  2.88310610e-02 -1.61455162e-02  6.17837207e-03
 -1.23234

##Sentence Embeddings with Transformers

In [None]:
# Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

In [13]:
# Our sentences we like to encode
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog."
]

# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
  model_output = model(**encoded_input)

In [14]:
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

In [15]:
# Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

In [16]:
# Print the embeddings
for sentence, embedding in zip(sentences, sentence_embeddings):
  print(f"Sentence: {sentence}")
  print(f"Embedding: {embedding}")
  print("#"*20)

Sentence: This framework generates embeddings for each input sentence
Embedding: tensor([-7.4600e-02, -2.3304e-01, -8.4994e-02,  7.6429e-02,  2.1511e-01,
         6.6237e-01,  1.6007e-01, -1.7268e-01,  1.9304e-01, -4.3134e-01,
         9.5649e-02, -2.1991e-01,  2.7043e-01,  1.3863e-01, -3.9086e-01,
         4.4321e-01,  7.9981e-03,  2.6084e-01, -2.4491e-01, -5.3958e-01,
        -1.5324e-01,  3.5080e-01,  2.4183e-01, -2.5898e-01, -1.9195e-01,
         2.3857e-01, -2.8745e-01,  2.3552e-03,  5.5429e-01,  8.9228e-02,
         1.7783e-01, -1.8816e-01,  6.5989e-02,  4.3228e-01,  2.4926e-02,
         8.5805e-02, -5.2655e-02,  1.5642e-01, -2.7508e-01, -8.4726e-02,
        -1.5657e-01, -5.2332e-02,  1.7161e-01,  1.2364e-01,  4.7393e-01,
        -2.0939e-01, -4.8114e-01, -4.7613e-02, -1.1548e-01,  1.1362e-01,
        -4.9058e-01, -2.8591e-01, -5.7450e-02,  1.5679e-01, -8.7805e-02,
         3.3600e-02, -6.7019e-02, -5.8374e-02,  1.5410e-01, -2.8745e-01,
        -1.9503e-01, -3.2521e-01, -5.9308e-