In [1]:
# References:
# https://github.com/UKPLab/sentence-transformers
# Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (https://arxiv.org/pdf/1908.10084)
# Data:
# https://sbert.net/datasets/stsbenchmark.tsv.gz

In [2]:
import torch
import torch.nn as nn
from torch import tensor 
from transformers import BertModel, BertTokenizer
import gzip
import pandas as pd

In [3]:
device = "cuda"

In [4]:
# Language model is defined (BERT + pooling)

In [5]:
class EmbeddingModel(nn.Module):
    def __init__(self, bertName = "bert-base-uncased"): # other bert models can also be supported
        super().__init__()
        self.bertName = bertName
        # use BERT model
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)        
       
    def forward(self, s, device = "cuda"):
        # get tokens, which also include attention_mask
        tokens = self.tokenizer(s, return_tensors='pt', padding = "max_length", truncation = True, max_length = 256).to(device)
        
        # get token embeddings
        output = self.model(**tokens)
        tokens_embeddings = output.last_hidden_state
        #print("tokens_embeddings:" + str(tokens_embeddings.shape))
        
        # mean pooling to get text embedding
        embeddings = tokens_embeddings * tokens.attention_mask[...,None] # [B, T, emb]
        #print("embeddings:" + str(embeddings.shape))
        
        embeddings = embeddings.sum(1) # [B, emb]
        valid_tokens = tokens.attention_mask.sum(1) # [B]
        embeddings = embeddings / valid_tokens[...,None] # [B, emb]    
        
        return embeddings

    # from scratch: nn.CosineSimilarity(dim = 1)(q,a)
    def cos_score(self, q, a): 
        q_norm = q / (q.pow(2).sum(dim=1, keepdim=True).pow(0.5))
        r_norm = a / (a.pow(2).sum(dim=1, keepdim=True).pow(0.5))
        return (q_norm @ r_norm.T).diagonal()

In [6]:
# Language model (such as BERT + pooling model above) can be used directly as embedding model. Train it on data with similarity score to get better result.

In [7]:
# contrastive training
class TrainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = EmbeddingModel("bert-base-uncased")

    def forward(self, s1, s2, score):        
        cos_score = self.m.cos_score(self.m(s1), self.m(s2))
        loss = nn.MSELoss()(cos_score, score)
        return loss, cos_score

In [8]:
train = TrainModel().to(device)

In [9]:
def test_train():
    loss, _ = train(sentences_q, sentences_a, tensor([0.8,0.7,0.7]).to(device))
    print(loss)
    loss.backward()
    with torch.no_grad():
        optimizer.step()
        optimizer.zero_grad()

In [10]:
# training data
#!wget https://sbert.net/datasets/stsbenchmark.tsv.gz

In [11]:
#!gzip -d stsbenchmark.tsv.gz

In [12]:
df = pd.read_csv("stsbenchmark.tsv", delimiter="\t", low_memory = False, on_bad_lines = 'skip',  skiprows=[8300])

In [13]:
df.isna().any(axis=None)

True

In [14]:
df = df.dropna().copy()

In [15]:
df['sentence1'] = df['sentence1'].apply(lambda x: x.strip())

In [16]:
df.head()

Unnamed: 0,split,genre,dataset,year,sid,score,sentence1,sentence2
0,train,main-captions,MSRvid,2012test,1,5.0,A plane is taking off.,An air plane is taking off.
1,train,main-captions,MSRvid,2012test,4,3.8,A man is playing a large flute.,A man is playing a flute.
2,train,main-captions,MSRvid,2012test,5,3.8,A man is spreading shreded cheese on a pizza.,A man is spreading shredded cheese on an uncoo...
3,train,main-captions,MSRvid,2012test,6,2.6,Three men are playing chess.,Two men are playing chess.
4,train,main-captions,MSRvid,2012test,9,4.25,A man is playing the cello.,A man seated is playing the cello.


In [17]:
# normalize to 0...1
df['score_'] = df['score']/5
len(df)

8282

In [18]:
df_train = df[df['split']=='train']
df_eval = df[df['split']=='dev']
df_test = df[df['split']=='test']
len(df_train), len(df_eval), len(df_test)

(5703, 1463, 1116)

In [19]:
df['split'].unique()

array(['train', 'dev', 'test'], dtype=object)

In [20]:
batch_size = 4
epochs = 1

In [21]:
def train_loop(training = True):
    if training == True:
        optimizer = torch.optim.AdamW(train.parameters(), lr=1e-5)
    losses = 0
    losses_eval = 0
    for i in range(0,len(df_train),batch_size):
        #if training == True:
        if 1:
            train.train()
            batch = df_train.iloc[i:i+batch_size]
            
            loss, _ = train(list(batch['sentence1']), list(batch['sentence2']), tensor(list(batch['score_'].apply(lambda x: float(x)))).float().to(device) )
            losses += loss
        
            loss.backward()
            with torch.no_grad():
                optimizer.step()
                optimizer.zero_grad()   
        
        #else:
            train.eval()
            ieval = i % len(df_eval)
            batch = df_eval.iloc[ieval:ieval+batch_size]
            with torch.no_grad():
                loss, _ = train(list(batch['sentence1']), list(batch['sentence2']), tensor(list(batch['score_'].apply(lambda x: float(x)))).float().to(device) )
            losses_eval += loss
            
        if (i % 200 == 0):
            print(f'batch {i}, loss {losses/200} eval {losses_eval/200}')
            losses = 0
            losses_eval = 0
            #if i > 1000:
            #    break

        #break

In [22]:
train_loop()

batch 0, loss 0.00027302277158014476 eval 0.0003261603997088969
batch 200, loss 0.01464193593710661 eval 0.024115221574902534
batch 400, loss 0.006054205354303122 eval 0.01201570499688387
batch 600, loss 0.007819064892828465 eval 0.010829409584403038
batch 800, loss 0.006070663221180439 eval 0.02249162830412388
batch 1000, loss 0.0053120506927371025 eval 0.01738635264337063
batch 1200, loss 0.00723802438005805 eval 0.0139848031103611
batch 1400, loss 0.005778112448751926 eval 0.02004367485642433
batch 1600, loss 0.006132963113486767 eval 0.01131927128881216
batch 1800, loss 0.006502231582999229 eval 0.006938617676496506
batch 2000, loss 0.006008408032357693 eval 0.006858306471258402
batch 2200, loss 0.012104008346796036 eval 0.013000665232539177
batch 2400, loss 0.008949452079832554 eval 0.010031194426119328
batch 2600, loss 0.009327622130513191 eval 0.010410006158053875
batch 2800, loss 0.005697215441614389 eval 0.010961009189486504
batch 3000, loss 0.006327271927148104 eval 0.0132357

In [23]:
torch.save(train,"myTextEmbedding.pt")

In [28]:
train_load=torch.load("myTextEmbedding.pt")

In [54]:
df_eval.head()

Unnamed: 0,split,genre,dataset,year,sid,score,sentence1,sentence2,score_
5706,dev,main-captions,MSRvid,2012test,0,5.0,A man with a hard hat is dancing.,A man wearing a hard hat is dancing.,1.0
5707,dev,main-captions,MSRvid,2012test,2,4.75,A young child is riding a horse.,A child is riding a horse.,0.95
5708,dev,main-captions,MSRvid,2012test,3,5.0,A man is feeding a mouse to a snake.,The man is feeding a mouse to the snake.,1.0
5709,dev,main-captions,MSRvid,2012test,7,2.4,A woman is playing the guitar.,A man is playing guitar.,0.48
5710,dev,main-captions,MSRvid,2012test,8,2.75,A woman is playing the flute.,A man is playing a flute.,0.55


In [29]:
m = train_load.m

In [30]:
for i in range(0,len(df_eval),batch_size):
    batch = df_eval.iloc[i:i+batch_size]
    print(m(list(batch['sentence1'])).shape)
    q = m(list(batch['sentence1']))
    a = m(list(batch['sentence2']))
    print(q)
    print(a)
    print(batch.iloc[0]['sentence1'])
    print(batch['sentence2'])
    scores = m.cos_score(m(list(batch['sentence1'])), m(list(batch['sentence2'])))
    #nn.CosineSimilarity(dim = 1)(q,a)
    print(scores)
    #print(torch.diagonal(scores))
    break

torch.Size([4, 768])
tensor([[-0.4339,  0.7072,  0.4792,  ..., -0.0786, -0.5483,  0.7098],
        [-0.2738, -1.2630, -1.2442,  ..., -0.3334,  0.3641, -0.6543],
        [ 0.7409, -0.2056,  0.0571,  ...,  0.4372,  0.1220,  0.6263],
        [ 0.3777, -1.3719,  0.2469,  ...,  0.4906, -0.4489, -0.2948]],
       device='cuda:0', grad_fn=<DivBackward0>)
tensor([[-0.3643,  0.6227,  0.4585,  ...,  0.0225, -0.5529,  0.7337],
        [-0.4042, -1.1851, -1.1947,  ..., -0.4143,  0.4112, -0.4713],
        [ 0.6979, -0.2874, -0.0107,  ...,  0.5579,  0.2560,  0.6040],
        [ 0.0861, -0.6448,  0.3536,  ...,  0.4708, -0.1786, -0.3542]],
       device='cuda:0', grad_fn=<DivBackward0>)
A man with a hard hat is dancing.
5706        A man wearing a hard hat is dancing.
5707                  A child is riding a horse.
5708    The man is feeding a mouse to the snake.
5709                    A man is playing guitar.
Name: sentence2, dtype: object
tensor([0.9837, 0.9856, 0.9746, 0.6640], device='cuda:0',
  