In [1]:
import os
from tqdm import tqdm
from sklearn.metrics import classification_report
from transformers import BertTokenizer, BertModel, BertConfig
from torch.nn import CosineSimilarity

In [2]:
query = 'A man is eating a food.'
corpus = ['A man is eating a piece of bread.',
          'The girl is carrying a baby.',
          'A man is riding a horse.',
          'A woman is playing violin.',
          'Two men pushed carts through the woods.',
          'A man is riding a white horse on an enclosed ground.',
          'A monkey is playing drums.',
          'A cheetah is running behind its prey.'
          ]

In [7]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

In [8]:
query_tokenized_ids = tokenizer.encode(query, return_tensors='pt')
inputs = tokenizer.batch_encode_plus(corpus, pad_to_max_length=True, return_tensors='pt')['input_ids']

In [12]:
tokenizer.convert_ids_to_tokens(query_tokenized_ids[0])

['[CLS]', 'a', 'man', 'is', 'eating', 'a', 'food', '.', '[SEP]']

In [13]:
last_hidden_states, output_pooled = model(inputs)
query_last_hidden_states, query_pooled = model(query_tokenized_ids)

In [14]:
output_pooled.shape

torch.Size([8, 768])

In [15]:
sent_query_cls = query_last_hidden_states[:, 0, :]
sents_corpus_cls = last_hidden_states[:, 0, :]

In [16]:
sent_query_cls.shape

torch.Size([1, 768])

In [17]:
sents_corpus_cls.shape

torch.Size([8, 768])

In [18]:
cos = CosineSimilarity(dim=1, eps=1e-6)
cos_results = cos(sent_query_cls, sents_corpus_cls)

In [19]:
cos_results

tensor([0.8860, 0.7309, 0.7749, 0.7390, 0.7404, 0.8676, 0.7890, 0.8669],
       grad_fn=<DivBackward0>)

In [20]:
corpus_sorted = [(x, float(y)) for y,x in sorted(zip(cos_results ,corpus), reverse = True)]

In [21]:
print('Text query: ', query)
print('Similarity (sorted): ')
for x, y in corpus_sorted:
    print('\t{}: \t{}'.format(x, y))

Text query:  A man is eating a food.
Similarity (sorted): 
	A man is eating a piece of bread.: 	0.8859956860542297
	A man is riding a white horse on an enclosed ground.: 	0.8675525188446045
	A cheetah is running behind its prey.: 	0.8668953776359558
	A monkey is playing drums.: 	0.7889615297317505
	A man is riding a horse.: 	0.7749326229095459
	Two men pushed carts through the woods.: 	0.7404251098632812
	A woman is playing violin.: 	0.7390106320381165
	The girl is carrying a baby.: 	0.7309496402740479
