In [1]:
# %pip install transformers
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import re
import numpy as np

In [2]:
df = pd.read_csv('examples.tsv', delimiter='\t')
df.head()

Unnamed: 0.1,Unnamed: 0,title,abstract
0,0,Analysis of Relative Gene Expression Data Usin...,The two most commonly used methods to analyze ...
1,1,Deep Residual Learning for Image Recognition,Deeper neural networks are more difficult to t...
2,2,A short history ofSHELX,An account is given of the development of the ...
3,3,Basic local alignment search tool,"A new approach to rapid sequence comparison, b..."
4,4,,Random forests are a combination of tree predi...


In [4]:
papers = list()
for index, row in df.iterrows():
    if row['title'] and row['abstract']:
        #sentences = re.split(r"\b[.!?;]\s", str(row['abstract']))
        #sentences = [sentence for sentence in sentences if len(sentence) > 0]
        papers.append(
            {
                'title': row['title'],
                'sentences': row['abstract']
            }
        )

In [5]:
df.iloc[0]['abstract']

'The two most commonly used methods to analyze data from real-time, quantitative PCR experiments are absolute quantification and relative quantification. Absolute quantification determines the input copy number, usually by relating the PCR signal to a standard curve. Relative quantification relates the PCR signal of the target transcript in a treatment group to that of another sample such as an untreated control. The 2 −ΔΔ C T method is a convenient way to analyze the relative changes in gene expression from real-time quantitative PCR experiments. The purpose of this report is to present the derivation, assumptions, and applications of the 2 −ΔΔ C T method. In addition, we present the derivation and applications of two variations of the 2 −ΔΔ C T method that may be useful in the analysis of real-time, quantitative PCR data. '

In [7]:
num_papers = 100
sentences = list()
sent_index = dict()
idx = 0
for paper in papers[:num_papers]:
    sentence=paper['sentences']
    sentences.append(sentence)
    if paper["title"] in sent_index.keys():
        sent_index[paper["title"]].append(idx)
    else: sent_index[paper["title"]] = [idx]
    idx += 1

In [8]:
len(sentences)

100

In [9]:
list(sent_index.keys())[:5]

['Analysis of Relative Gene Expression Data Using Real-Time Quantitative PCR and the 2−ΔΔCT Method',
 'Deep Residual Learning for Image Recognition',
 'A short history ofSHELX',
 'Basic local alignment search tool',
 nan]

In [10]:
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')

In [14]:
tokens = {'input_ids': [], 'attention_mask': []}
# encode each sentence and append to dictionary
new_tokens = tokenizer.encode_plus(sentences, max_length=1024,
                                    truncation=True, padding='max_length',
                                    return_tensors='pt')
tokens['input_ids'].append(new_tokens['input_ids'][0])
tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        
tokens['input_ids'] = torch.stack(tokens['input_ids'])
tokens['attention_mask'] = torch.stack(tokens['attention_mask'])

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

In [10]:
outputs = model(**tokens)
outputs.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [11]:
embeddings = outputs.last_hidden_state
embeddings

tensor([[[-5.3435e-01,  1.6028e-01, -6.1676e-02,  ..., -5.4697e-01,
           1.2346e-01,  4.1495e-01],
         [-5.0971e-01, -2.3434e-01,  1.6776e-01,  ..., -6.6199e-01,
           2.4208e-01, -1.2140e-02],
         [-5.9011e-01, -4.8241e-01, -4.6855e-01,  ..., -4.3545e-01,
           3.7199e-01, -4.9600e-01],
         ...,
         [ 1.8631e-02,  3.3102e-01, -3.7121e-02,  ..., -1.2575e-01,
           1.7510e-01,  8.1307e-02],
         [-4.4296e-02,  4.6495e-01, -9.6532e-02,  ..., -1.9146e-01,
           2.0625e-01,  3.0093e-01],
         [-1.0972e-01,  1.7571e-01,  1.9726e-01,  ..., -2.8996e-01,
           1.5726e-01,  5.1086e-01]],

        [[-2.4747e-01, -1.4939e-01,  5.9418e-01,  ..., -4.4344e-01,
          -2.6041e-02,  1.2963e+00],
         [ 2.3194e-01, -2.6046e-01,  7.3989e-01,  ..., -4.4468e-01,
          -2.6440e-01,  1.4715e+00],
         [ 1.8986e-01, -2.1698e-02,  7.3000e-01,  ..., -4.6517e-01,
          -1.4789e-01,  1.3726e+00],
         ...,
         [-4.1239e-02, -2

In [12]:
embeddings.shape

torch.Size([553, 64, 768])

In [13]:
attention_mask = tokens['attention_mask']
attention_mask.shape

torch.Size([553, 64])

In [14]:
mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
mask.shape

torch.Size([553, 64, 768])

In [15]:
mask

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

In [16]:
masked_embeddings = embeddings * mask
masked_embeddings.shape

torch.Size([553, 64, 768])

In [17]:
summed = torch.sum(masked_embeddings, 1)
summed.shape

torch.Size([553, 768])

In [18]:
summed_mask = torch.clamp(mask.sum(1), min=1e-9)
summed_mask.shape

torch.Size([553, 768])

In [49]:
mean_pooled = summed / summed_mask
mean_pooled

tensor([[-0.4531,  0.0464,  0.0839,  ..., -0.4959,  0.1306,  0.4264],
        [ 0.0896, -0.2147,  0.7119,  ..., -0.2331, -0.2212,  1.3551],
        [ 0.0097, -0.3247,  0.0984,  ..., -0.2072, -0.3513,  0.6373],
        ...,
        [-0.8234, -0.3840,  0.1930,  ..., -0.7831, -0.4611,  0.2648],
        [-0.1730,  0.2009,  1.9040,  ...,  0.3567,  0.1157, -0.2297],
        [-0.3987, -0.0754,  0.0825,  ..., -0.2219,  0.9837, -0.3423]],
       grad_fn=<DivBackward0>)

In [50]:
from sklearn.metrics.pairwise import cosine_similarity
# convert from PyTorch tensor to numpy array
mean_pooled = mean_pooled.detach().numpy()

In [55]:
idx_for_calc = 143
title = ""
for key, value in sent_index.items():
    if idx_for_calc in value:
        title = key
        break

paper_sentences = sent_index[title]

sims_before_idx = cosine_similarity(
    [mean_pooled[idx_for_calc]],
    mean_pooled[:idx_for_calc]
)

sims_after_idx = cosine_similarity(
    [mean_pooled[idx_for_calc]],
    mean_pooled[idx_for_calc+1:]
)

sims = np.concatenate((sims_before_idx.flatten(), sims_after_idx.flatten()))

best_match = np.argmax(sims)
best_match_title = ""
for key, value in sent_index.items():
    if best_match in value:
        best_match_title = key
        break

print(title)
print(sentences[idx_for_calc])
print("\n\n")
print(best_match_title)
print(sentences[best_match])
print("\n\n")
print(best_match)
print(sims[best_match])

Fiji: an open-source platform for biological-image analysis
Fiji is a distribution of the popular open-source software ImageJ focused on biological-image analysis



Fiji: an open-source platform for biological-image analysis
Fiji facilitates the transformation of new algorithms into ImageJ plugins that can be shared with end users through an integrated update system



145
0.76168305


In [54]:
print(sims)

[0.46239728 0.33788925 0.39739355 0.3802843  0.33302778 0.37088016
 0.35646805 0.38762814 0.26237863 0.19554248 0.556841   0.395887
 0.3253775  0.26050463 0.55227226 0.42126304 0.33455715 0.42371053
 0.62224925 0.45363986 0.3403346  0.28376842 0.31671005 0.31298366
 0.38045582 0.42452607 0.17142643 0.2811961  0.4088418  0.38938493
 0.4498933  0.3133253  0.40549955 0.6359988  0.25429323 0.40363026
 0.4213499  0.3171335  0.20705165 0.35639337 0.5443872  0.56720245
 0.45477343 0.51819927 0.32603386 0.36495817 0.4596569  0.6425656
 0.25463194 0.34557438 0.24160475 0.23005675 0.4157809  0.22408906
 0.39049423 0.31698838 0.47318757 0.3834569  0.29157725 0.19908924
 0.52320254 0.13875228 0.44894883 0.3382006  0.4406801  0.33522755
 0.23036927 0.2668664  0.5326203  0.24192293 0.47447193 0.30999953
 0.32030275 0.47608387 0.26522398 0.1934872  0.2971031  0.24254312
 0.37986884 0.19779769 0.19779769 0.19779769 0.30786252 0.38047245
 0.33073336 0.2891006  0.17759429 0.3896925  0.49011233 0.5272039