In [1]:
import sys
sys.path.append('../w2v_service')

import json
import pickle

from gensim.models import Word2Vec
import gensim.downloader as api
import numpy

In [2]:
from w2v.pip_loss import calculate_pip_distance
from w2v.utils import load_ref_data, load_vocabulary

In [3]:
REF_DATA, REF_MATRIX = load_ref_data('../data')
VOCABULARY = load_vocabulary('../data/preprocessed_dataset.pickle')

In [4]:
model = Word2Vec.load(
    "../third/materials-word-embeddings/bin/word2vec_embeddings-SNAPSHOT.model")
print(model)
pip_dist_base = calculate_pip_distance(REF_DATA, REF_MATRIX, model)
print(pip_dist_base)
print(model.epochs)

Word2Vec(vocab=171603, size=100, alpha=0.025)
0.2906765998675416
5


In [5]:
from functools import partial
from gensim.models.callbacks import CallbackAny2Vec


pip_dist_calc = partial(calculate_pip_distance, REF_DATA, REF_MATRIX)

class EpochLoss(CallbackAny2Vec):
    def __init__(self):
        self.epoch = 0
    
    def on_epoch_end(self, model):
        self.epoch += 1
        print(f"{self.epoch}: {pip_dist_calc(model)} {model.alpha}")

In [6]:
model.__dict__

{'max_final_vocab': None,
 'callbacks': (),
 'load': <function gensim.utils.call_on_class_only(*args, **kwargs)>,
 'wv': <gensim.models.keyedvectors.Word2VecKeyedVectors at 0x187c670cc88>,
 'vocabulary': <gensim.models.word2vec.Word2VecVocab at 0x187d321afc8>,
 'trainables': <gensim.models.word2vec.Word2VecTrainables at 0x187c1d19b88>,
 'sg': 0,
 'alpha': 0.025,
 'window': 5,
 'random': RandomState(MT19937) at 0x187C10BEBF8,
 'min_alpha': 0.0001,
 'hs': 0,
 'negative': 5,
 'ns_exponent': 0.75,
 'cbow_mean': 1,
 'compute_loss': False,
 'running_training_loss': 0,
 'min_alpha_yet_reached': 0.025,
 'corpus_count': 2616049,
 'corpus_total_words': None,
 'vector_size': 100,
 'workers': 20,
 'epochs': 5,
 'train_count': 6420,
 'total_train_time': 15857.520852804184,
 'batch_words': 10000,
 'model_trimmed_post_training': False}

In [7]:
model.build_vocab(VOCABULARY, update=True, progress_per=1)

In [8]:
model.train(VOCABULARY,
             total_examples=len(VOCABULARY),
             start_alpha=.1,
             end_alpha=0.5,
             epochs=15,
             callbacks=[EpochLoss()])

1: 0.29641261129393787 0.1
2: 0.29990413687382783 0.1
3: 0.2923346452257005 0.1
4: 0.2891877809434924 0.1
5: 0.2946684803206732 0.1
6: 0.28194451422097055 0.1
7: 0.2789140782665042 0.1
8: 0.29272589552094413 0.1
9: 0.2980959080283236 0.1
10: 0.2933887794616834 0.1
11: 0.2984549844934341 0.1
12: 0.305764577250028 0.1
13: 0.31352894005949355 0.1
14: 0.30382291213919077 0.1
15: 0.2932706116461968 0.1


(36584802, 38101440)

In [9]:
for ref in REF_DATA:
    try:
        print(f"{ref}: {model.wv.most_similar(positive=[ref])[:3]}")
    except KeyError as err:
        print(err)

austenite: [('nearly', 0.7236950397491455), ('ferrite', 0.699272632598877), ('martensite', 0.6804725527763367)]
cobalt: [('slow', 0.7201051712036133), ('cnt', 0.7187455892562866), ('isolated', 0.673039436340332)]
thickness: [('polishing', 0.7641053795814514), ('rram', 0.6823971271514893), ('steps', 0.6806197166442871)]
hv: [('ubiquitous', 0.626889705657959), ('poset', 0.6238789558410645), ('normalized', 0.6187589764595032)]
temperature: [('equilibrating', 0.7545353174209595), ('goss', 0.7445894479751587), ('prey', 0.7408083081245422)]
casting: [('preferable', 0.7417943477630615), ('adsorbent', 0.7343586683273315), ('rnas', 0.7246285080909729)]
oxygen: [('deforming', 0.8256338238716125), ('addressed', 0.7601260542869568), ('increasingly', 0.7358825206756592)]
oxide: [('waa', 0.686445951461792), ('maximizing', 0.6795468926429749), ('fratura', 0.6771236658096313)]
descaling: [('condense', 0.8986173272132874), ('circulating', 0.8259493112564087), ('instances', 0.763194739818573)]
sulfur: [

manganese: [('libr', 0.760863184928894), ('ridership', 0.676112949848175), ('purple', 0.6667639017105103)]
tin: [('phosphatase', 0.7505589723587036), ('poag', 0.742942214012146), ('thomson', 0.7399418950080872)]
magnetic: [('twitch', 0.9651629328727722), ('fs', 0.8618342280387878), ('singkat', 0.8033508658409119)]
interstitial: [('viscera', 0.8643798232078552), ('amnion', 0.8633195161819458), ('kidneys', 0.8406178951263428)]
twinning: [('incorporates', 0.6968389749526978), ('entropy', 0.6859356164932251), ('maintenance', 0.6761112213134766)]
martensitic: [('gfr', 0.710545003414154), ('connect', 0.704312264919281), ('cop', 0.6949915289878845)]
hydrogen: [('pneumothorax', 0.734032392501831), ('steering', 0.7162342667579651), ('profiling', 0.6983475089073181)]
vanadium: [('antiadherent', 0.6644886136054993), ('photocatalytic', 0.6525822281837463), ('true', 0.650658130645752)]
cooled: [('contained', 0.7765299081802368), ('could', 0.7331807613372803), ('nbs', 0.7321070432662964)]
carbon: [(

In [10]:
pip_dist = calculate_pip_distance(REF_DATA, REF_MATRIX, model)
print(pip_dist)

0.2932706116461968
