# Embedding Training

This notebook demonstrates how to train embeddings:

- Train Wasserstein embeddings
- Train Kullback-Leibler embeddings

## Load packages and data

In [1]:
import os
import sys
from time import time
import pickle

import logging
# Configure logging at INFO level; every record is rendered as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

In [2]:
import numpy as np
import tensorflow as tf

from utils import *
from kl_net import *
from wass_net import *

Load the preprocessed data: the vocabulary, the word-to-id mapping, and the positive and negative word-pair samples.

In [3]:
def _read_pickle(path):
    """Return the object deserialized from the pickle file at ``path``."""
    with open(path, 'rb') as infile:
        return pickle.load(infile)


# Restore the preprocessed artifacts stored under ./data.
vocab = _read_pickle('./data/vocab.pkl')
vocab2id = _read_pickle('./data/vocab2id.pkl')
pos_samples = _read_pickle('./data/pos_samples.pkl')
neg_samples = _read_pickle('./data/neg_samples.pkl')

logging.info("Load vocabulary and word pairs from local file.")

2019-05-21 23:16:33,029 - INFO - Load vocabulary and word pairs from local file.


In [4]:
# Dataset tag; it is interpolated into the names of the result files written later.
file_name = 'WikiSmall'

# Stack the samples: every positive sample first, then every negative sample.
X_train = np.array(list(pos_samples) + list(neg_samples))

# Labels in the same order: 1 per positive sample, 0 per negative sample.
# np.repeat builds the integer vector directly instead of first materializing
# one Python int per sample in two temporary lists.
y_train = np.repeat([1, 0], [len(pos_samples), len(neg_samples)])

# Every sample must have exactly one label.
assert X_train.shape[0] == y_train.shape[0], "X_train and y_train differ in length"

# NOTE(review): samples are ordered by class; whether the training functions
# shuffle before batching is not visible in this notebook — confirm.
X_train.shape, y_train.shape

((12753368, 2), (12753368,))

## Wasserstein Embedding

Define the hyperparameters for Wasserstein embedding training.

In [9]:
# Hyperparameters forwarded to train_wass in the next cell.
# Passed as `dim`; also appears in the log messages and the output file name.
embed_dim = 32
# Passed as `ground_dim`; also appears in the log messages and the output file name.
ground_dim = 2
# Passed as `n_epochs`.
n_epochs = 1
# Passed as `batch_size`.
batch_size = 2048
# Passed as `m`; its meaning is defined inside train_wass — TODO confirm.
m = 10
# Passed as `lambd`; presumably a weighting coefficient — TODO confirm against wass_net.
lambd = 0.5

In [10]:
# No pre-trained weights are supplied to the trainer.
pre_trained_embeddings = None

# Name of the result archive (np.savez appends the ".npz" extension).
output_stem = '{}_WassR{}_{}_batch'.format(file_name, ground_dim, embed_dim)

# Create the output directory *before* the long training run, so a missing
# ./results folder cannot make the final save fail after training has finished.
os.makedirs('./results', exist_ok=True)

logging.info("Running Wasserstein R{} embedding, embed dim={}".format(ground_dim, embed_dim))

# train_wass returns two values, stored below as `embeddings` and `embed_distances`.
embeddings, embed_distances = train_wass(X_train, y_train, vocab_size=len(vocab),
                                         pre_trained_weights=pre_trained_embeddings,
                                         dim=embed_dim, learning_rate=0.001,
                                         n_epochs=n_epochs,
                                         ground_dim=ground_dim, lambd=lambd, m=m,
                                         batch_size=batch_size)

# Persist both results in a single archive under ./results.
logging.info("Writing {} to local file".format(output_stem))
np.savez('./results/{}'.format(output_stem),
         embeddings=embeddings, embed_distances=embed_distances)

2019-05-22 03:25:29,378 - INFO - Running Wasserstein R2 embedding, embed dim=32
2019-05-22 03:25:29,392 - INFO - Initialize embeddings by random.
2019-05-22 03:25:32,896 - INFO - Epoch No.1/1 - Batch No.1/6228 with 2048 samples: Loss 94503.703125
2019-05-22 03:25:32,897 - INFO - Storing weights (embeddings)
2019-05-22 03:25:47,906 - INFO - Epoch No.1/1 - Batch No.11/6228 with 2048 samples: Loss 95136.5078125
2019-05-22 03:26:03,092 - INFO - Epoch No.1/1 - Batch No.21/6228 with 2048 samples: Loss 92915.71875
2019-05-22 03:26:18,458 - INFO - Epoch No.1/1 - Batch No.31/6228 with 2048 samples: Loss 92993.671875
2019-05-22 03:26:33,812 - INFO - Epoch No.1/1 - Batch No.41/6228 with 2048 samples: Loss 91887.125
2019-05-22 03:26:49,253 - INFO - Epoch No.1/1 - Batch No.51/6228 with 2048 samples: Loss 92046.046875
2019-05-22 03:27:04,807 - INFO - Epoch No.1/1 - Batch No.61/6228 with 2048 samples: Loss 94020.8203125
2019-05-22 03:27:20,422 - INFO - Epoch No.1/1 - Batch No.71/6228 with 2048 sample

2019-05-22 03:48:53,733 - INFO - Epoch No.1/1 - Batch No.781/6228 with 2048 samples: Loss 79992.96875
2019-05-22 03:49:13,731 - INFO - Epoch No.1/1 - Batch No.791/6228 with 2048 samples: Loss 84297.6953125
2019-05-22 03:49:33,722 - INFO - Epoch No.1/1 - Batch No.801/6228 with 2048 samples: Loss 82839.5390625
2019-05-22 03:49:53,662 - INFO - Epoch No.1/1 - Batch No.811/6228 with 2048 samples: Loss 84220.671875
2019-05-22 03:50:13,938 - INFO - Epoch No.1/1 - Batch No.821/6228 with 2048 samples: Loss 80693.3359375
2019-05-22 03:50:34,037 - INFO - Epoch No.1/1 - Batch No.831/6228 with 2048 samples: Loss 80403.265625
2019-05-22 03:50:54,085 - INFO - Epoch No.1/1 - Batch No.841/6228 with 2048 samples: Loss 81749.03125
2019-05-22 03:51:14,147 - INFO - Epoch No.1/1 - Batch No.851/6228 with 2048 samples: Loss 81969.25
2019-05-22 03:51:34,342 - INFO - Epoch No.1/1 - Batch No.861/6228 with 2048 samples: Loss 79335.6015625
2019-05-22 03:51:54,970 - INFO - Epoch No.1/1 - Batch No.871/6228 with 2048

2019-05-22 04:16:28,477 - INFO - Epoch No.1/1 - Batch No.1571/6228 with 2048 samples: Loss 73917.4296875
2019-05-22 04:16:49,794 - INFO - Epoch No.1/1 - Batch No.1581/6228 with 2048 samples: Loss 72760.53125
2019-05-22 04:17:11,145 - INFO - Epoch No.1/1 - Batch No.1591/6228 with 2048 samples: Loss 72143.6875
2019-05-22 04:17:33,123 - INFO - Epoch No.1/1 - Batch No.1601/6228 with 2048 samples: Loss 74830.234375
2019-05-22 04:17:54,452 - INFO - Epoch No.1/1 - Batch No.1611/6228 with 2048 samples: Loss 73112.25
2019-05-22 04:18:16,114 - INFO - Epoch No.1/1 - Batch No.1621/6228 with 2048 samples: Loss 73446.6875
2019-05-22 04:18:37,561 - INFO - Epoch No.1/1 - Batch No.1631/6228 with 2048 samples: Loss 70440.6796875
2019-05-22 04:18:58,925 - INFO - Epoch No.1/1 - Batch No.1641/6228 with 2048 samples: Loss 72837.25
2019-05-22 04:19:20,389 - INFO - Epoch No.1/1 - Batch No.1651/6228 with 2048 samples: Loss 71210.15625
2019-05-22 04:19:41,810 - INFO - Epoch No.1/1 - Batch No.1661/6228 with 2048

2019-05-22 04:45:08,307 - INFO - Epoch No.1/1 - Batch No.2361/6228 with 2048 samples: Loss 67677.9296875
2019-05-22 04:45:30,860 - INFO - Epoch No.1/1 - Batch No.2371/6228 with 2048 samples: Loss 66227.0703125
2019-05-22 04:45:52,817 - INFO - Epoch No.1/1 - Batch No.2381/6228 with 2048 samples: Loss 67886.7890625
2019-05-22 04:46:14,842 - INFO - Epoch No.1/1 - Batch No.2391/6228 with 2048 samples: Loss 66220.4453125
2019-05-22 04:46:36,793 - INFO - Epoch No.1/1 - Batch No.2401/6228 with 2048 samples: Loss 67504.5859375
2019-05-22 04:46:58,806 - INFO - Epoch No.1/1 - Batch No.2411/6228 with 2048 samples: Loss 66058.125
2019-05-22 04:47:20,877 - INFO - Epoch No.1/1 - Batch No.2421/6228 with 2048 samples: Loss 67446.171875
2019-05-22 04:47:42,950 - INFO - Epoch No.1/1 - Batch No.2431/6228 with 2048 samples: Loss 67666.890625
2019-05-22 04:48:05,054 - INFO - Epoch No.1/1 - Batch No.2441/6228 with 2048 samples: Loss 67739.71875
2019-05-22 04:48:27,173 - INFO - Epoch No.1/1 - Batch No.2451/6

2019-05-22 05:14:05,898 - INFO - Epoch No.1/1 - Batch No.3141/6228 with 2048 samples: Loss 62983.20703125
2019-05-22 05:14:28,407 - INFO - Epoch No.1/1 - Batch No.3151/6228 with 2048 samples: Loss 65143.390625
2019-05-22 05:14:50,873 - INFO - Epoch No.1/1 - Batch No.3161/6228 with 2048 samples: Loss 64928.34375
2019-05-22 05:15:13,533 - INFO - Epoch No.1/1 - Batch No.3171/6228 with 2048 samples: Loss 63517.16796875
2019-05-22 05:15:35,836 - INFO - Epoch No.1/1 - Batch No.3181/6228 with 2048 samples: Loss 63268.09375
2019-05-22 05:15:58,298 - INFO - Epoch No.1/1 - Batch No.3191/6228 with 2048 samples: Loss 64200.29296875
2019-05-22 05:16:20,772 - INFO - Epoch No.1/1 - Batch No.3201/6228 with 2048 samples: Loss 63375.08984375
2019-05-22 05:16:43,153 - INFO - Epoch No.1/1 - Batch No.3211/6228 with 2048 samples: Loss 63544.53515625
2019-05-22 05:17:05,487 - INFO - Epoch No.1/1 - Batch No.3221/6228 with 2048 samples: Loss 63331.25
2019-05-22 05:17:28,529 - INFO - Epoch No.1/1 - Batch No.323

2019-05-22 05:44:00,540 - INFO - Epoch No.1/1 - Batch No.3931/6228 with 2048 samples: Loss 60236.5
2019-05-22 05:44:23,459 - INFO - Epoch No.1/1 - Batch No.3941/6228 with 2048 samples: Loss 59897.2578125
2019-05-22 05:44:46,283 - INFO - Epoch No.1/1 - Batch No.3951/6228 with 2048 samples: Loss 60910.15625
2019-05-22 05:45:09,451 - INFO - Epoch No.1/1 - Batch No.3961/6228 with 2048 samples: Loss 62643.89453125
2019-05-22 05:45:32,222 - INFO - Epoch No.1/1 - Batch No.3971/6228 with 2048 samples: Loss 60780.6328125
2019-05-22 05:45:55,154 - INFO - Epoch No.1/1 - Batch No.3981/6228 with 2048 samples: Loss 60000.12109375
2019-05-22 05:46:18,183 - INFO - Epoch No.1/1 - Batch No.3991/6228 with 2048 samples: Loss 60554.84765625
2019-05-22 05:46:41,023 - INFO - Epoch No.1/1 - Batch No.4001/6228 with 2048 samples: Loss 60668.28125
2019-05-22 05:46:41,023 - INFO - Storing weights (embeddings)
2019-05-22 05:47:03,948 - INFO - Epoch No.1/1 - Batch No.4011/6228 with 2048 samples: Loss 59156.140625
2

2019-05-22 06:14:03,696 - INFO - Epoch No.1/1 - Batch No.4711/6228 with 2048 samples: Loss 59649.359375
2019-05-22 06:14:27,118 - INFO - Epoch No.1/1 - Batch No.4721/6228 with 2048 samples: Loss 59763.84375
2019-05-22 06:14:50,319 - INFO - Epoch No.1/1 - Batch No.4731/6228 with 2048 samples: Loss 58714.46875
2019-05-22 06:15:13,806 - INFO - Epoch No.1/1 - Batch No.4741/6228 with 2048 samples: Loss 60771.9296875
2019-05-22 06:15:37,209 - INFO - Epoch No.1/1 - Batch No.4751/6228 with 2048 samples: Loss 58600.5625
2019-05-22 06:16:00,549 - INFO - Epoch No.1/1 - Batch No.4761/6228 with 2048 samples: Loss 59650.01953125
2019-05-22 06:16:23,846 - INFO - Epoch No.1/1 - Batch No.4771/6228 with 2048 samples: Loss 58745.25390625
2019-05-22 06:16:47,096 - INFO - Epoch No.1/1 - Batch No.4781/6228 with 2048 samples: Loss 59656.375
2019-05-22 06:17:10,370 - INFO - Epoch No.1/1 - Batch No.4791/6228 with 2048 samples: Loss 58259.73828125
2019-05-22 06:17:34,136 - INFO - Epoch No.1/1 - Batch No.4801/62

2019-05-22 06:44:39,811 - INFO - Epoch No.1/1 - Batch No.5491/6228 with 2048 samples: Loss 58800.83203125
2019-05-22 06:45:03,620 - INFO - Epoch No.1/1 - Batch No.5501/6228 with 2048 samples: Loss 56018.37890625
2019-05-22 06:45:27,347 - INFO - Epoch No.1/1 - Batch No.5511/6228 with 2048 samples: Loss 57597.546875
2019-05-22 06:45:51,041 - INFO - Epoch No.1/1 - Batch No.5521/6228 with 2048 samples: Loss 57519.265625
2019-05-22 06:46:14,833 - INFO - Epoch No.1/1 - Batch No.5531/6228 with 2048 samples: Loss 57819.4453125
2019-05-22 06:46:38,383 - INFO - Epoch No.1/1 - Batch No.5541/6228 with 2048 samples: Loss 59858.421875
2019-05-22 06:47:02,090 - INFO - Epoch No.1/1 - Batch No.5551/6228 with 2048 samples: Loss 57996.60546875
2019-05-22 06:47:25,710 - INFO - Epoch No.1/1 - Batch No.5561/6228 with 2048 samples: Loss 58382.0390625
2019-05-22 06:47:49,510 - INFO - Epoch No.1/1 - Batch No.5571/6228 with 2048 samples: Loss 59493.08984375
2019-05-22 06:48:13,402 - INFO - Epoch No.1/1 - Batch 

## KL Embedding

Define the hyperparameters for Kullback-Leibler (KL) embedding training.

In [17]:
# Hyperparameters forwarded to train_kl in the next cell.
# NOTE: embed_dim, n_epochs, batch_size and m rebind the names set in the
# Wasserstein section, so re-running earlier cells after this one uses these values.
# Passed as `learning_rate`.
lr = 0.001
# Passed as `dim`; also appears in the log messages and the output file name.
embed_dim = 300
# Passed as `n_epochs`.
n_epochs = 1
# Passed as `batch_size`.
batch_size = 2048
# Passed as `m`; its meaning is defined inside train_kl — TODO confirm.
m = 10

In [None]:
# No pre-trained weights are supplied to the trainer.
pre_trained_embeddings = None

# Name of the result archive (np.savez appends the ".npz" extension).
output_stem = '{}_{}_{}_batch'.format(file_name, 'KL', embed_dim)

# Create the output directory *before* the long training run, so a missing
# ./results folder cannot make the final save fail after training has finished.
os.makedirs('./results', exist_ok=True)

logging.info("Running KL embedding, embed dim={}".format(embed_dim))

# train_kl returns two values, stored below as `embeddings` and `embed_distances`.
embeddings, embed_distances = train_kl(X_train, y_train, vocab_size=len(vocab),
                                       pre_trained_weights=pre_trained_embeddings,
                                       dim=embed_dim, learning_rate=lr,
                                       n_epochs=n_epochs, m=m,
                                       batch_size=batch_size)

# Persist both results in a single archive under ./results.
logging.info("Writing {} to local file".format(output_stem))
np.savez('./results/{}'.format(output_stem),
         embeddings=embeddings, embed_distances=embed_distances)

2019-05-20 09:18:20,839 - INFO - Running KL embedding, embed dim=300
2019-05-20 09:18:21,107 - INFO - Initialize embeddings by random.
