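### Demo: saving and loading Feature2Vec embeddings

Train Feature2Vec on the CSLB feature matrix, save the learned vectors to disk, overwrite them with random noise, and then reload them to show that the saved embeddings are restored.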

In [1]:
import os

# Hide all GPUs *before* importing Feature2Vec: that import pulls in
# TensorFlow (note the "Using TensorFlow backend." output), and the visible
# CUDA device list should be fixed before TensorFlow is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
from Feature2Vec import Feature2Vec
from utils import *

# Seed NumPy's global RNG so the random steps below (e.g. the concept
# shuffle) are reproducible.
SEED = 42
np.random.seed(seed = SEED)

# Feature matrix to train on. Alternative dataset:
#     path = 'data/mcrae_feature_matrix.csv'
path = 'data/cslb_feature_matrix.csv'

print('Building feature2vec')
model = Feature2Vec(path = path)

Using TensorFlow backend.


Building feature2vec


In [2]:
# Put the concepts in a random (seeded, hence reproducible) order and use
# that ordering as the training vocabulary.
concept_array = np.asarray(model.concepts)
shuffle = np.random.permutation(len(model.concepts))
train_concepts = list(concept_array[shuffle])
model.set_vocabulary(train_words = train_concepts)

In [3]:
print('Training feature2vec')
# Training hyperparameters, collected in one place for easy tuning.
train_settings = dict(verbose = 1, epochs = 20, lr = 5e-3, negative_samples = 20)
model.train(**train_settings)
# Only the final epoch line survives in the output, so the progress line is
# presumably overwritten in place; emit a newline once training is done.
print('')

Training feature2vec
Epoch: 19 Loss: 0.004500256059271702


In [4]:
# Save the trained vectors to disk. Create the output directory first so a
# fresh checkout without an `embeddings/` folder does not fail here
# (harmless if the directory already exists).
os.makedirs('embeddings', exist_ok = True)
model.save('embeddings/cslb_embeddings.txt')

In [5]:
# Rank the 5 concepts closest to the learned 'has_keys' feature vector.
has_keys_vector = model.fvector('has_keys')
model.rank_neighbours(has_keys_vector, 5)

array([['0.3063240021408405', 'harpsichord'],
       ['0.2824758151071688', 'clarinet'],
       ['0.27327723302755025', 'typewriter'],
       ['0.2683656110461591', 'piano'],
       ['0.25110279449164563', 'organ_(musical_instrument)']],
      dtype='<U26')

In [6]:
np.shape(model.feature_vectors)

(2725, 300)

In [7]:
# Overwrite the learned vectors with uniform random noise. Use the matrix's
# own shape instead of a hard-coded size so that every feature row is
# replaced and feature lookups stay in range.
model.feature_vectors = np.random.rand(*model.feature_vectors.shape)

In [8]:
# With random vectors in place, the top related concepts no longer make sense.
noisy_has_keys = model.fvector('has_keys')
model.rank_neighbours(noisy_has_keys, 5)

array([['0.17161331371219862', 'cape'],
       ['0.15091490064481322', 'certificate'],
       ['0.12758664586654367', 'willow'],
       ['0.12429329657070862', 'tent'],
       ['0.12090653818457522', 'dates']], dtype='<U19')

In [9]:
# Read the previously saved vectors back from disk, replacing the random ones.
saved_embeddings_path = 'embeddings/cslb_embeddings.txt'
model.load(saved_embeddings_path)

In [10]:
# After reloading, the neighbours match those ranked before randomization.
# Scores differ only in the last few digits, presumably from the text-file
# round-trip.
restored_has_keys = model.fvector('has_keys')
model.rank_neighbours(restored_has_keys, 5)

array([['0.30632400211416344', 'harpsichord'],
       ['0.28247581514937375', 'clarinet'],
       ['0.27327723307175084', 'typewriter'],
       ['0.26836561105087015', 'piano'],
       ['0.2511027944505316', 'organ_(musical_instrument)']], dtype='<U26')