In [6]:
import torch
import requests
from helper import bert_split
import multiprocessing as mp
import matplotlib.pyplot as plt

from sklearn.metrics.pairwise import cosine_similarity

In [1]:
from transformers import AutoTokenizer, AutoModel

# Load the LaBSE tokenizer and encoder from the "sentence-transformers/LaBSE"
# checkpoint (the cell output shows the files being downloaded).
# NOTE: both `tokenizer` and `model` are reassigned further down in this
# notebook, so these objects are only in effect until those cells run.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/LaBSE")
model = AutoModel.from_pretrained("sentence-transformers/LaBSE")

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=654.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=5220781.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=112.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=49.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1888175414.0), HTML(value='')))




In [7]:
# Landing pages to compare (scheme-less; 'https://' is prepended when fetching).
websites = ['www.epfl.ch/en',
            'www.epfl.ch/fr',
            'www.caltech.edu/',
            'www.bbc.com',
            'www.foxnews.com',
            'www.skysports.com/',
            'www.redbull.com/int-en/',
            'www.apple.com/',
            'www.samsung.com/us/',
            'www.samsung.com/fr/',
            'www.huawei.com/us/',
            'democrats.org/',
            'www.spd.de',
            'www.parti-socialiste.fr/',
            'www.nzz.ch/']

# Upper bound (seconds) for connecting to / reading from each site. Without a
# timeout `requests.get` can block indefinitely on an unresponsive host, which
# would hang the whole notebook run.
REQUEST_TIMEOUT_S = 30

# Raw HTML of every landing page, in the same order as `websites`.
# NOTE: the HTTP status is not checked, so an error page returned by a site is
# kept (and later embedded) like any other page.
websites_html = [requests.get('https://' + url, timeout=REQUEST_TIMEOUT_S).text
                 for url in websites]

In [53]:
# from sentence-transformers
def mean_pooling(model_output, attention_mask):
    """Average token embeddings per sequence, ignoring masked-out positions.

    Args:
        model_output: indexable model output whose element 0 holds the token
            embeddings (assumed shape (batch, seq, dim) -- the reduction over
            dimension 1 relies on this).
        attention_mask: mask tensor (assumed shape (batch, seq)); positions with
            value 0 do not contribute to the average.

    Returns:
        Tensor with the sequence dimension reduced away: the mask-weighted sum
        of token embeddings divided by the mask sum.
    """
    embeddings = model_output[0]
    # Broadcast the mask over the embedding dimension so it can weight tokens.
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    weighted_total = (embeddings * weights).sum(1)
    # Clamp the denominator so a fully masked sequence does not divide by zero.
    weight_total = weights.sum(1).clamp(min=1e-9)
    return weighted_total / weight_total

# Re-load tokenizer and encoder from the "pvl/labse_bert" checkpoint. This
# rebinds the `tokenizer` and `model` globals that were assigned earlier in the
# notebook from "sentence-transformers/LaBSE"; `embed_labse` below picks up
# whichever objects these names point to when it is called.
# do_lower_case=False is passed to the tokenizer -- presumably to keep the
# input casing; confirm against the checkpoint's documentation.
tokenizer = AutoTokenizer.from_pretrained("pvl/labse_bert", do_lower_case=False)
model = AutoModel.from_pretrained("pvl/labse_bert")

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=472.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=5220781.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=112.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=62.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1888168867.0), HTML(value='')))




In [62]:
def embed_labse(body):
    """Embed one document as the average of its mean-pooled sentence embeddings.

    The text is split with `bert_split`, tokenized with the global `tokenizer`
    (padded, truncated to 128 tokens), run through the global `model` without
    gradient tracking, mean-pooled per sentence and finally averaged over all
    sentences.

    Args:
        body: the document to embed, passed as-is to `bert_split`.

    Returns:
        A tensor: the mean over dimension 0 of the pooled sentence embeddings.
    """
    chunks = bert_split(body)

    # Restrict torch to one thread in this process -- presumably because the
    # function is run from a multiprocessing pool and the workers should not
    # oversubscribe the CPU.
    torch.set_num_threads(1)

    batch = tokenizer(
        chunks,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )

    with torch.no_grad():
        outputs = model(**batch)

    per_sentence = mean_pooling(outputs, batch['attention_mask'])
    return per_sentence.mean(axis=0)

In [65]:
# Embed all pages in parallel, one task per page.
# The worker count is capped at 15 (the value originally hard-coded here) and
# never exceeds the number of pages; `max(1, ...)` keeps Pool() valid even if
# `websites_html` is empty.
MAX_WORKERS = 15
n_workers = max(1, min(MAX_WORKERS, len(websites_html)))

# Using the pool as a context manager guarantees the worker processes are
# cleaned up even when `embed_labse` raises inside `map`. `map` blocks until
# every result has been collected, so nothing is lost when the pool is torn
# down on exit.
with mp.Pool(n_workers) as pool:
    websites_emb = pool.map(embed_labse, websites_html)

In [68]:
# Inspect the value range of the first page's embedding: smallest component.
websites_emb[0].min()

tensor(-5.7686)

In [69]:
# Inspect the value range of the first page's embedding: largest component.
websites_emb[0].max()

tensor(6.0882)

In [71]:
# Pairwise cosine similarity between all page embeddings. Each embedding is
# converted to a plain Python list before being handed to scikit-learn.
sim = cosine_similarity(
    [embedding.tolist() for embedding in websites_emb]
)

In [72]:
# For every site, print its two most similar *other* sites together with their
# cosine similarities. Requires at least three sites in `websites`.
for i in range(len(websites)):
    # Rank all sites by similarity to site i, best first, then drop site i
    # itself explicitly. Relying on the site sorting last is fragile: when
    # similarities tie (e.g. two identical pages) its position in the argsort
    # is not guaranteed and it could be reported as its own neighbour.
    ranked = sim[i].argsort()[::-1]
    ix = ranked[ranked != i][:2]
    print('closest to {:<23} : {:>23} ({:.4f}), {:>23} ({:.4f})'.format(websites[i][:22],
                                                           websites[ix[0]][:22],
                                                           sim[ix[0], i],
                                                           websites[ix[1]][:22],
                                                           sim[ix[1], i]))

closest to www.epfl.ch/en          :          www.epfl.ch/fr (0.9951),        www.caltech.edu/ (0.9814)
closest to www.epfl.ch/fr          :          www.epfl.ch/en (0.9951),        www.caltech.edu/ (0.9755)
closest to www.caltech.edu/        :          www.epfl.ch/en (0.9814),          www.epfl.ch/fr (0.9755)
closest to www.bbc.com             :         www.foxnews.com (0.9775),             www.nzz.ch/ (0.9650)
closest to www.foxnews.com         :             www.bbc.com (0.9775),             www.nzz.ch/ (0.9712)
closest to www.skysports.com/      :         www.foxnews.com (0.9661),  www.redbull.com/int-en (0.9627)
closest to www.redbull.com/int-en  :         www.foxnews.com (0.9698),             www.bbc.com (0.9630)
closest to www.apple.com/          :     www.samsung.com/fr/ (0.9618),     www.samsung.com/us/ (0.9577)
closest to www.samsung.com/us/     :     www.samsung.com/fr/ (0.9820),         www.foxnews.com (0.9629)
closest to www.samsung.com/fr/     :     www.samsung.com/us/ (0.

In [2]:
from sentence_transformers import SentenceTransformer

# Load LaBSE through sentence-transformers.
# NOTE: this rebinds the global `model`, which earlier in the notebook held the
# transformers AutoModel used by `embed_labse`; after this cell runs,
# `embed_labse` would call this object instead.
model = SentenceTransformer('LaBSE')

In [10]:
def embed_labse_fast(body):
    """Embed one document as the average of its sentence embeddings.

    Splits `body` with `bert_split`, encodes the pieces with the global
    `model` (which must expose `.encode`, i.e. the SentenceTransformer
    instance) and averages the resulting vectors over axis 0.

    Args:
        body: the document to embed, passed as-is to `bert_split`.

    Returns:
        The mean over axis 0 of whatever `model.encode` returns for the
        sentences.
    """
    chunks = bert_split(body)
    sentence_vectors = model.encode(chunks)
    return sentence_vectors.mean(axis=0)

In [9]:
# Embed every fetched page sequentially, keeping the order of `websites_html`.
websites_emb = list(map(embed_labse_fast, websites_html))

In [11]:
# Smallest component of the first page's embedding. The magnitude is far
# smaller than the value shown for the AutoModel + mean_pooling path earlier in
# the notebook -- presumably because the sentence-transformers pipeline
# normalizes its sentence embeddings; TODO confirm.
websites_emb[0].min()

-0.10074113

In [12]:
# Largest component of the first page's embedding (compare with the minimum
# inspected in the previous cell to see the value range).
websites_emb[0].max()

0.017665477

In [14]:
# Pairwise cosine similarity between all page embeddings; a shallow copy of the
# embedding list is handed to scikit-learn.
sim = cosine_similarity(list(websites_emb))

In [15]:
# For every site, print its two most similar *other* sites together with their
# cosine similarities. Requires at least three sites in `websites`.
for i in range(len(websites)):
    # Rank all sites by similarity to site i, best first, then drop site i
    # itself explicitly. Relying on the site sorting last is fragile: when
    # similarities tie (e.g. two identical pages) its position in the argsort
    # is not guaranteed and it could be reported as its own neighbour.
    ranked = sim[i].argsort()[::-1]
    ix = ranked[ranked != i][:2]
    print('closest to {:<23} : {:>23} ({:.4f}), {:>23} ({:.4f})'.format(websites[i][:22],
                                                           websites[ix[0]][:22],
                                                           sim[ix[0], i],
                                                           websites[ix[1]][:22],
                                                           sim[ix[1], i]))

closest to www.epfl.ch/en          :          www.epfl.ch/fr (0.9799),        www.caltech.edu/ (0.9044)
closest to www.epfl.ch/fr          :          www.epfl.ch/en (0.9799),        www.caltech.edu/ (0.8850)
closest to www.caltech.edu/        :          www.epfl.ch/en (0.9044),          www.epfl.ch/fr (0.8850)
closest to www.bbc.com             :         www.foxnews.com (0.9123),             www.nzz.ch/ (0.8661)
closest to www.foxnews.com         :             www.bbc.com (0.9123),             www.nzz.ch/ (0.8857)
closest to www.skysports.com/      :  www.redbull.com/int-en (0.8293),         www.foxnews.com (0.8225)
closest to www.redbull.com/int-en  :         www.foxnews.com (0.8499),              www.spd.de (0.8372)
closest to www.apple.com/          :     www.samsung.com/fr/ (0.8857),     www.samsung.com/us/ (0.8714)
closest to www.samsung.com/us/     :     www.samsung.com/fr/ (0.9564),          www.apple.com/ (0.8714)
closest to www.samsung.com/fr/     :     www.samsung.com/us/ (0.