In [1]:
!pip install sentence_transformers

Looking in indexes: https://pypi.org/simple, https://packagecloud.io/github/git-lfs/pypi/simple


In [14]:
from sentence_transformers import SentenceTransformer, util

# Two paraphrases whose embeddings we want to compare.
sentences = ["I'm happy", "I'm full of happiness"]
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Encode each sentence on its own (in list order), asking for tensors back.
embedding_1, embedding_2 = (
    model.encode(sentence, convert_to_tensor=True) for sentence in sentences
)

# Cosine similarity between the two embeddings; shown as the cell's result.
util.pytorch_cos_sim(embedding_1, embedding_2)
     


tensor([[0.6003]], device='cuda:0')

In [16]:
# Small FAQ knowledge base used by the semantic-search cells that follow.
# Each key is an FAQ question and each value is the answer text for it.
# Data from https://faq.ssa.gov/en-US/topic/?id=CAT-01092

faq = {
    "How do I get a replacement Medicare card?": "If your Medicare card was lost, stolen, or destroyed, you can request a replacement online at Medicare.gov.",
    "How do I sign up for Medicare?": "If you already get Social Security benefits, you do not need to sign up for Medicare. We will automatically enroll you in Original Medicare (Part A and Part B) when you become eligible. We will mail you the information a few months before you become eligible.",
    "What are Medicare late enrollment penalties?": "In most cases, if you don’t sign up for Medicare when you’re first eligible, you may have to pay a higher monthly premium. Find more information at https://faq.ssa.gov/en-us/Topic/article/KA-02995",
    "Will my Medicare premiums be higher because of my higher income?": "Some people with higher income may pay a larger percentage of their monthly Medicare Part B and prescription drug costs based on their income. We call the additional amount the income-related monthly adjustment amount.",
    "What is Medicare and who can get it?": "Medicare is a health insurance program for people age 65 or older. Some younger people are eligible for Medicare including people with disabilities, permanent kidney failure and amyotrophic lateral sclerosis (Lou Gehrig’s disease or ALS). Medicare helps with the cost of health care, but it does not cover all medical expenses or the cost of most long-term care.",
}

In [17]:
# Embed the FAQ answer texts (the dict values, not the question keys),
# in dict order, so row i of the result corresponds to the i-th FAQ entry.
faq_answers = list(faq.values())
corpus_embeddings = model.encode(faq_answers, convert_to_tensor=True)
print(corpus_embeddings.shape)

torch.Size([5, 384])


In [18]:
# Embed the user's free-text query with the same model used for the FAQ corpus.
user_question = "Do I need to pay more after a raise?"
query_embedding = model.encode(
    user_question,
    convert_to_tensor=True,
)
# Display the embedding's shape as the cell result.
query_embedding.shape
     

torch.Size([384])

In [19]:
# Rank the corpus embeddings against the query and keep the three best hits.
# Only one query embedding is passed, so the first (index 0) result list is
# the one we want. Each hit is read below via its "corpus_id" and "score" keys.
similarities = util.semantic_search(
    query_embedding, corpus_embeddings, top_k=3
)[0]

# Build the (question, answer) pairs once instead of re-creating
# list(faq.keys()) / list(faq.values()) on every loop iteration. The pairs
# are in dict order, which is the order the corpus was embedded in.
faq_items = list(faq.items())

# NOTE(review): the "p=" label prints the raw score returned by
# semantic_search; it is a similarity score, presumably not a probability.
for rank, hit in enumerate(similarities, start=1):
    question, answer = faq_items[hit["corpus_id"]]
    score = hit["score"]
    print(f"Top {rank} question (p={score}): {question}")
    print(f"Answer: {answer}")

Top 1 question (p=0.35796284675598145): Will my Medicare premiums be higher because of my higher income?
Answer: Some people with higher income may pay a larger percentage of their monthly Medicare Part B and prescription drug costs based on their income. We call the additional amount the income-related monthly adjustment amount.
Top 2 question (p=0.2787758708000183): What are Medicare late enrollment penalties?
Answer: In most cases, if you don’t sign up for Medicare when you’re first eligible, you may have to pay a higher monthly premium. Find more information at https://faq.ssa.gov/en-us/Topic/article/KA-02995
Top 3 question (p=0.15840473771095276): How do I sign up for Medicare?
Answer: If you already get Social Security benefits, you do not need to sign up for Medicare. We will automatically enroll you in Original Medicare (Part A and Part B) when you become eligible. We will mail you the information a few months before you become eligible.


In [20]:
from transformers import AutoTokenizer, AutoModel

# Load the tokenizer and the bare model for the same checkpoint name that
# was given to SentenceTransformer earlier in this notebook.
# NOTE: this rebinds the notebook-level name `model`. Re-running an earlier
# cell that calls `model.encode(...)` after this one would use this object.
checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

In [21]:
# Tokenize a single sentence. The displayed result holds the input_ids,
# token_type_ids and attention_mask tensors for it.
encoded_input = tokenizer(
    'This is an example sentence',
    return_tensors='pt',
)
encoded_input

{'input_ids': tensor([[ 101, 2023, 2003, 2019, 2742, 6251,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [22]:
import torch

# Run one forward pass on the tokenized sentence with gradient tracking
# disabled (torch.no_grad), then display the returned output object.
with torch.no_grad():
    model_output = model(**encoded_input)
model_output

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.0366, -0.0162,  0.1682,  ...,  0.0554, -0.1644, -0.2967],
         [ 0.7239,  0.6399,  0.1888,  ...,  0.5946,  0.6206,  0.4897],
         [ 0.0064,  0.0203,  0.0448,  ...,  0.3464,  1.3170, -0.1670],
         ...,
         [ 0.1479, -0.0643,  0.1457,  ...,  0.8837, -0.3316,  0.2975],
         [ 0.5212,  0.6563,  0.5607,  ..., -0.0399,  0.0412, -1.4036],
         [ 1.0824,  0.7140,  0.3986,  ..., -0.2301,  0.3243, -1.0313]]]), pooler_output=tensor([[ 1.3429e-02,  4.0036e-02,  3.0797e-03,  7.7094e-03, -8.5741e-02,
         -3.2874e-02,  4.5395e-02,  5.4421e-02, -6.6219e-02, -3.3736e-02,
         -7.4499e-03,  3.3775e-02, -1.8523e-02, -1.2477e-02, -6.1699e-02,
          7.9306e-02,  9.3979e-02, -2.9625e-02, -1.4692e-02,  5.6033e-02,
          1.1484e-02,  1.1056e-02,  2.2872e-02, -2.9034e-02, -1.8243e-02,
          1.3069e-01, -2.4484e-02,  5.1790e-02,  3.6784e-02,  8.1075e-02,
          8.6604e-02,  3.3905e-04, -

In [24]:
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked tokens.

    Args:
        model_output: Mapping-like object exposing the per-token embeddings
            under the key ``"last_hidden_state"``.
        attention_mask: Mask tensor with one entry per token; it is given a
            trailing dimension and expanded to the embeddings' size, so it
            must have one dimension fewer than the embeddings. In this
            notebook it holds 1 for real tokens (presumably 0 for padding).

    Returns:
        The mask-weighted sum of the embeddings over dimension 1, divided by
        the sum of the mask over dimension 1. The divisor is clamped to at
        least 1e-9 so a fully masked sequence does not divide by zero.
    """
    hidden_states = model_output["last_hidden_state"]
    # Stretch the mask over the embedding dimension so it can weight tokens.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_sum = (hidden_states * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_counts

# Pool the per-token outputs into one vector per sentence, weighting the
# tokens by the tokenizer's attention mask.
attention_mask = encoded_input['attention_mask']
sentence_embeddings = mean_pooling(model_output, attention_mask)
sentence_embeddings

tensor([[ 3.4935e-01,  3.2786e-01,  2.5153e-01,  4.0949e-01,  1.9336e-01,
          1.3698e-02,  2.0331e-01, -3.6653e-02,  3.0651e-01,  1.6284e-01,
          3.1032e-01, -2.7318e-01,  2.0967e-01, -1.3389e-01,  1.5409e-01,
          5.8186e-03,  3.7959e-01, -2.6015e-01, -6.3194e-01,  1.2239e-01,
          1.5349e-01,  2.1933e-01,  1.3236e-01,  1.0302e-02, -2.9390e-01,
         -1.4024e-01, -1.6990e-01,  3.4092e-01,  6.1449e-01, -2.3690e-01,
         -3.7498e-01, -1.6825e-01,  2.7026e-01,  2.3264e-01,  4.2614e-02,
          1.8951e-01, -7.1987e-02,  3.3765e-01, -1.3646e-01,  1.0659e-03,
         -7.0556e-02, -1.8734e-01, -1.0071e-01, -1.4961e-01,  2.0358e-01,
         -4.5650e-01,  1.3550e-02,  7.0592e-02,  2.4943e-01, -1.6088e-01,
         -6.0583e-01, -2.6421e-01, -4.5712e-01, -1.1306e-01,  7.3831e-02,
          2.2935e-01, -6.9612e-02,  3.8385e-01,  1.3755e-01, -1.0263e-01,
          9.2525e-02, -5.4760e-02, -4.6691e-01,  1.1012e-01,  7.2911e-01,
         -3.3417e-02, -7.2486e-03, -7.

In [26]:
import torch.nn.functional as F

# Rescale each row to unit L2 norm. p=2 and dim=1 are F.normalize's
# defaults, written out here so the intent is explicit.
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
sentence_embeddings

tensor([[ 6.7657e-02,  6.3496e-02,  4.8713e-02,  7.9305e-02,  3.7448e-02,
          2.6528e-03,  3.9375e-02, -7.0984e-03,  5.9361e-02,  3.1537e-02,
          6.0098e-02, -5.2905e-02,  4.0607e-02, -2.5931e-02,  2.9843e-02,
          1.1269e-03,  7.3515e-02, -5.0382e-02, -1.2239e-01,  2.3703e-02,
          2.9727e-02,  4.2477e-02,  2.5634e-02,  1.9952e-03, -5.6919e-02,
         -2.7160e-02, -3.2904e-02,  6.6025e-02,  1.1901e-01, -4.5879e-02,
         -7.2621e-02, -3.2584e-02,  5.2341e-02,  4.5055e-02,  8.2530e-03,
          3.6702e-02, -1.3941e-02,  6.5392e-02, -2.6427e-02,  2.0642e-04,
         -1.3664e-02, -3.6281e-02, -1.9504e-02, -2.8974e-02,  3.9427e-02,
         -8.8409e-02,  2.6243e-03,  1.3671e-02,  4.8306e-02, -3.1157e-02,
         -1.1733e-01, -5.1169e-02, -8.8529e-02, -2.1896e-02,  1.4299e-02,
          4.4417e-02, -1.3482e-02,  7.4339e-02,  2.6638e-02, -1.9876e-02,
          1.7919e-02, -1.0605e-02, -9.0426e-02,  2.1327e-02,  1.4120e-01,
         -6.4717e-03, -1.4038e-03, -1.

In [32]:
# Index the first row of the batch to show that sentence's embedding as a
# 1-D tensor.
sentence_embeddings[0]

tensor([ 6.7657e-02,  6.3496e-02,  4.8713e-02,  7.9305e-02,  3.7448e-02,
         2.6528e-03,  3.9375e-02, -7.0984e-03,  5.9361e-02,  3.1537e-02,
         6.0098e-02, -5.2905e-02,  4.0607e-02, -2.5931e-02,  2.9843e-02,
         1.1269e-03,  7.3515e-02, -5.0382e-02, -1.2239e-01,  2.3703e-02,
         2.9727e-02,  4.2477e-02,  2.5634e-02,  1.9952e-03, -5.6919e-02,
        -2.7160e-02, -3.2904e-02,  6.6025e-02,  1.1901e-01, -4.5879e-02,
        -7.2621e-02, -3.2584e-02,  5.2341e-02,  4.5055e-02,  8.2530e-03,
         3.6702e-02, -1.3941e-02,  6.5392e-02, -2.6427e-02,  2.0642e-04,
        -1.3664e-02, -3.6281e-02, -1.9504e-02, -2.8974e-02,  3.9427e-02,
        -8.8409e-02,  2.6243e-03,  1.3671e-02,  4.8306e-02, -3.1157e-02,
        -1.1733e-01, -5.1169e-02, -8.8529e-02, -2.1896e-02,  1.4299e-02,
         4.4417e-02, -1.3482e-02,  7.4339e-02,  2.6638e-02, -1.9876e-02,
         1.7919e-02, -1.0605e-02, -9.0426e-02,  2.1327e-02,  1.4120e-01,
        -6.4717e-03, -1.4038e-03, -1.5361e-02, -8.7

* https://www.leebutterman.com/2023/06/01/offline-realtime-embedding-search.html 