In [142]:
import pandas as pd
import numpy as np
import requests
import re
from bs4 import BeautifulSoup

import torch
from transformers import BertTokenizer, BertModel

In [56]:
def clean_html(body):
    """ Parse the html body and returns only the words appearing on the website """
    
    try:
        soup = BeautifulSoup(body, 'html.parser')
        a = soup.get_text(' ').replace('\n', '').lower()
        b = re.sub(r"[^a-z ]+","", a)
        c = ' '.join(b.split())
        return c
    except:
        return ''

In [132]:
def bert_sentence_cleaner(s):
    s = re.sub(r"’", "'", s) # transform the french ' into english ones
    
    match = r"[^a-zA-z\u00C0-\u00FF '’-]" # match anything that is not a letter (incl. accentued), space, apostrophe and dash 
    match += "|" + "(\W\-|\-\W)+" # match any dash that is not between letters
    match += "|" + "(\W'|'\W)+" # match any apostrophe that is not between letters
    
    s = re.sub(match, "", s) # remove the matches characters
    s = re.sub(r"\s+"," ", s) # replace any whitespace with a space
    s = re.sub(r" +"," ", s) # remove any sucession of spaces
    s = s.strip() # trim the final sentence
    return s

In [133]:
def bert_splitter(body):
    soup = BeautifulSoup(body, 'html.parser')
    a = soup.get_text('[SEP]').split('[SEP]')
    b = [s.split('.') for s in a if len(s) != 0]
    flat_b = [clean_sentence(s) for sublist in b for s in sublist]
    trimmed_b = [s for s in flat_b if len(s) != 0]
    return trimmed_b

In [134]:
url = 'https://www.epfl.ch/fr/'
r = requests.get(url)

In [135]:
splitter(r.text)

['École polytechnique fédérale de Lausanne EPFL',
 'Aller directement au contenu',
 'À propos',
 'Éducation',
 'Recherche',
 'Innovation',
 'Facultés',
 'Campus',
 'Info coronavirus',
 'Rechercher sur le site',
 'Valider',
 'Afficher masquer le formulaire de recherche',
 'Rechercher sur le site',
 'Masquer le formulaire de recherche',
 'FR',
 'EN',
 'Menu',
 "Les propriétés de matériaux D modifiées à l'échelle nanométrique",
 'Actualités',
 'EPFL',
 'En savoir plus',
 'sur Tech Transfer',
 "Grâce à une pointe nanométrique des scientifiques de l'EPFL parviennent à modifier les propriétés physiques de matériaux bidimensionnels de façon permanente en les déformant",
 'Une étape importante pour les intégrer dans les dispositifs électroniques et optoélectroniques',
 'En savoir plus',
 'Un mécanisme qui protège les cellules du vieillissement prématuré',
 'Actualités',
 'EPFL',
 "Une nouvelle étude des chercheurs de l'EPFL montre comment les molécules d'ARN appelées TERRA se rassemblent à l'e

In [154]:
df = pd.read_csv('/dlabdata1/lugeon/websites_1000_5cat_emb.gz')
df

Unnamed: 0.1,Unnamed: 0,uid,emb,cat
0,0,100943,,Arts
1,4,77075,"[0.03447194676070164, -0.030272552844447395, -...",Arts
2,5,197517,"[-0.0331517226572725, -0.06529109943325498, -0...",Arts
3,6,146723,"[0.9296558499336243, 0.08666697889566422, 0.73...",Arts
4,7,287942,"[0.16634249522588973, -0.14721398854758158, -0...",Arts
5,8,109175,"[0.09260120582347982, -0.18637233806549813, 0....",Arts


In [148]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased');
model = BertModel.from_pretrained('bert-base-cased', output_hidden_states = True);
model.eval();

In [149]:
def embed(text):
    
    tokenized_text = tokenizer.tokenize(text)
    print(tokenized_text)
    
    if len(tokenized_text) > 510:
        tokenized_text = tokenized_text[:510]
        
    tokenized_text = ['[CLS]'] + tokenized_text + ['[SEP]']
    
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1] * len(indexed_tokens)
    
    print(indexed_tokens)

    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    
    with torch.no_grad():
        output = model(tokens_tensor, segments_tensors)
        hidden_states = output[2] # tensor of shape 
    
    token_states = torch.stack(hidden_states, dim=0)
    token_states = torch.squeeze(token_states, dim=1) 
    token_embeddings = token_states[-2] 
    
    return torch.mean(token_embeddings, dim=0)