In [1]:
import numpy as np
from gensim.models import Word2Vec
from gensim.models import Doc2Vec
from gensim.models.doc2vec import TaggedDocument
from gensim.models import KeyedVectors

In [4]:
# Cleaned training corpus: one document per line, tokens separated by whitespace
# (read via `corpus_file=` by the training cells and via `line.split()` below).
clean_dataset_file = "datasets/global_datasetMin.txt"

## Modèles word2vec

### Entraînement

In [3]:
%%time 
model = Word2Vec(corpus_file=clean_dataset_file, vector_size=100, window=5, min_count=2, workers=8, sg=1)

In [4]:
%%time 
model_200 = Word2Vec(corpus_file=clean_dataset_file, vector_size=200, window=5, min_count=2, workers=8, sg=1)

In [13]:
%%time 
model_300 = Word2Vec(corpus_file=clean_dataset_file, vector_size=300, window=5, min_count=2, workers=8, sg=1)

CPU times: user 15min 12s, sys: 7.89 s, total: 15min 20s
Wall time: 2min 20s


### Sauvegarde des modèles

In [6]:
# Persist only the word vectors (`.wv`, a KeyedVectors object), not the full trainable model.
# NOTE(review): the "_1M" suffix does not match clean_dataset_file ("global_datasetMin").
# Re-running this cell overwrites the "_1M" file, possibly with a model trained on a different
# corpus than the name suggests. Confirm the intended file name.
wv = model.wv
wv.save("models/w2vec_model_d100_1M")

In [7]:
# Persist only the 200-d word vectors (`.wv`, a KeyedVectors object).
# NOTE(review): the "_1M" suffix does not match clean_dataset_file ("global_datasetMin").
# Confirm which corpus this file is meant to hold.
wv = model_200.wv
wv.save("models/w2vec_model_d200_1M")

In [14]:
# Persist only the 300-d word vectors (`.wv`, a KeyedVectors object); the file name
# carries the corpus name ("global_Min") used for training.
wv = model_300.wv
wv.save("models/w2vec_model_d300_global_Min")

## Modèle doc2vec

### Préparation des données

In [7]:
# Read the cleaned corpus: one document per line, tokens separated by whitespace.
# Explicit UTF-8 instead of the platform-dependent default encoding (e.g. cp1252 on
# Windows), which would mis-decode the (presumably French, accented) text.
with open(clean_dataset_file, "r", encoding="utf-8") as file:
    lines = file.readlines()

clean_data = [line.split() for line in lines]

# Tag each document with its line index. This matches the tagging Doc2Vec applies when
# trained with `corpus_file=` (tags = line numbers), so `doc_data` can be passed as
# `documents=` instead if an in-memory corpus is preferred.
# NOTE(review): the training cell below uses `corpus_file=` and does not read `doc_data`.
doc_data = [TaggedDocument(words=data, tags=[i]) for i, data in enumerate(clean_data)]

### Entraînement

In [5]:
%%time

doc2v_model = Doc2Vec(corpus_file=clean_dataset_file, vector_size=300, workers=8, dbow_words = 1, min_count=2, epochs=10)

CPU times: user 9min 19s, sys: 5.54 s, total: 9min 25s
Wall time: 1min 24s


### Sauvegarde

In [6]:
# Save the full Doc2Vec model (not just its vectors) so `infer_vector` is still available
# after reloading. It goes under models/, next to the word2vec files, which is the path
# the usage section loads it from.
doc2v_model.save('models/doc2v_model_global_Min')

## Exemple d'utilisation

In [9]:
w2v_100d_model = KeyedVectors.load("models/w2vec_model_d100_1M")
# Take the dimension from the loaded vectors instead of hard-coding it, so it always
# matches the file actually loaded.
model_dim = w2v_100d_model.vector_size

In [1]:
def encode(msg, model, dim=None):
    """
    Encode a message as the mean of its word vectors.

    msg : list of strings, the words of the message. A plain string is also
          accepted and is split on whitespace (iterating a raw string would
          otherwise average per-character vectors).
    model : word -> vector mapping supporting ``word in model`` and
            ``model[word]`` (e.g. gensim KeyedVectors).
    dim : dimension of the vectors in ``model``. Only used to size the zero
          vector returned when no word is known; if None, it is read from
          ``model.vector_size``.

    Returns the mean of the vectors of the words of ``msg`` that are in ``model``,
    or a vector of ``dim`` zeros if none of them is in the model.
    """
    if isinstance(msg, str):
        msg = msg.split()
    vectors = [model[word] for word in msg if word in model]
    if not vectors:
        # No known word: fall back to a zero vector of the model's dimension.
        if dim is None:
            dim = model.vector_size
        return np.zeros(dim)
    return np.mean(vectors, axis=0)

In [11]:
# Tokenise the example sentence the same way the corpus is split (on whitespace).
# NOTE(review): vocabulary lookup is an exact string match; if the corpus was
# lower-cased, "Mon" will be out of vocabulary. Confirm against the cleaning step.
phrase = "Mon tailleur est riche".split()

encode(msg=phrase, model=w2v_100d_model, dim=model_dim)

array([ 4.18351516e-02, -3.00552994e-02,  2.93862492e-01, -3.22550237e-01,
        5.67456722e-01, -3.53782058e-01,  6.20573759e-04,  8.60297233e-02,
       -1.15959801e-01, -2.19277292e-01,  4.27981466e-02, -1.07317016e-01,
       -9.58199203e-02,  7.21499473e-02,  2.75918841e-01, -2.22999305e-02,
        3.32399011e-01, -3.17189880e-02, -3.35851997e-01, -1.76247358e-01,
       -1.28344111e-02, -1.09718651e-01, -4.38710690e-01, -4.01031561e-02,
       -1.23104177e-01,  2.94926643e-01, -4.17810678e-01, -1.66892141e-01,
       -5.40134430e-01, -2.85382941e-02,  4.30060804e-01, -1.78062543e-01,
       -1.87055111e-01, -5.38300157e-01,  2.41067857e-02,  1.73097700e-01,
        1.64522097e-01,  2.74127841e-01,  9.55908373e-02, -4.62221682e-01,
        2.32840717e-01, -3.88882101e-01, -3.06208700e-01,  1.65211111e-01,
        3.99650335e-01,  2.25923762e-01, -1.10780224e-01,  5.52508235e-03,
        1.28892869e-01,  2.55674630e-01,  5.34657463e-02, -5.20176888e-02,
       -7.70943537e-02, -

In [8]:
# Reload the full Doc2Vec model; `infer_vector` needs the complete model, not just vectors.
d2v = Doc2Vec.load("models/doc2v_model_global_Min")

In [14]:
# Infer a document vector for the tokenised example phrase with the reloaded Doc2Vec model.
d2v.infer_vector(doc_words=phrase)

array([ 3.25924344e-02,  9.65289548e-02, -1.87063217e-02,  1.41754434e-01,
       -3.84703316e-02, -3.41895707e-02,  4.88149188e-02,  8.10005963e-02,
        1.94979366e-02, -4.11345698e-02, -9.99840908e-03,  2.36345623e-02,
        1.24998698e-02,  1.59241781e-02, -3.27239893e-02,  1.49624618e-02,
        3.52247506e-02, -7.17094466e-02,  5.75929508e-03,  1.43895345e-02,
       -2.88615678e-03,  1.86007116e-02,  7.33369816e-05, -2.89414078e-02,
        8.62125456e-02, -4.36879881e-02,  2.14435384e-02, -1.96394809e-02,
       -2.14762203e-02, -2.17972808e-02, -1.91847719e-02,  1.59782264e-02,
       -3.18497978e-02,  3.07049938e-02, -2.11984292e-03,  6.06949488e-03,
        2.37713382e-02,  1.33918328e-02,  3.61802205e-02,  1.98899099e-04,
       -2.84323655e-02, -1.55254826e-02,  7.29476754e-03, -2.08594166e-02,
       -2.01968756e-02,  1.18922833e-02, -1.18742818e-02,  2.64078658e-02,
        3.19818370e-02,  1.72190424e-02,  8.82046111e-03,  5.03719710e-02,
        1.34704576e-03, -