In [2]:
import math
import os
from os.path import join
import time

from absl import app
from absl import flags
from absl import logging
from tqdm import tqdm
import numpy as np

from keras.optimizers import Adam

# from data import Embedding, MultiLanguageEmbedding, \
#     LazyIndexCorpus,  Word2vecIterator, BilbowaIterator

from data import *
from model import get_model, word2vec_loss, bilbowa_loss, strong_pair_loss, weak_pair_loss
import sys
sys.path.insert(0, '../../reference/eval')
from evaluate import Evaluator



In [3]:
def _load_bilingual(path0, path1):
    """Load one embedding file per language and merge them.

    Returns a 5-tuple:
    (lang0_embedding, lang1_embedding, merged, merged_vocab, merged_matrix),
    where the last two are the merged MultiLanguageEmbedding's
    get_vocab() and get_emb() results.
    """
    lang0 = Embedding(path0)
    lang1 = Embedding(path1)
    merged = MultiLanguageEmbedding(lang0, lang1)
    return lang0, lang1, merged, merged.get_vocab(), merged.get_emb()


# Word embeddings: English (language 0) and French (language 1).
emb0, emb1, emb, vocab, emb_matrix = _load_bilingual(
    "./data_root/withctx.en-fr.en.50.1.txt",
    "./data_root/withctx.en-fr.fr.50.1.txt")

# Context embeddings: the `.ctx` companion files of the word embeddings above.
ctxemb0, ctxemb1, ctxemb, ctxvocab, ctxemb_matrix = _load_bilingual(
    "./data_root/withctx.en-fr.en.50.1.txt.ctx",
    "./data_root/withctx.en-fr.fr.50.1.txt.ctx")

# The evaluator is built from the per-language word embeddings.
evaluator = Evaluator(emb0, emb1)


100%|██████████| 995003/995003 [00:17<00:00, 56611.95 words/s]
100%|██████████| 432455/432455 [00:04<00:00, 88011.86 words/s] 
 77%|███████▋  | 763484/995003 [00:13<00:04, 54872.23 words/s]

KeyboardInterrupt: 

 77%|███████▋  | 763484/995003 [00:30<00:09, 25443.79 words/s]

In [57]:
# Word pairs from read_pair(), split into a "strong" and a "weak" set, then
# mapped onto ids of the merged vocabulary held by `emb`.
strong, weak = read_pair()
(strong_id, weak_id,
 l0_dict, l1_dict) = pair2id(strong, weak, emb)

In [59]:
# get_model receives both emb_matrix and ctxemb_matrix for the same
# `nb_word=len(vocab)`, so word and context vocabularies must list the same
# words in the same order.
assert tuple(ctxvocab) == tuple(vocab), (
    'context vocabulary (%d entries) does not match word vocabulary '
    '(%d entries) element-by-element' % (len(ctxvocab), len(vocab)))

In [62]:
# Monolingual corpora (one per language) feed the word2vec objective; the
# parallel corpora (en_multi / fr_multi) feed the BilBOWA objective.
# `max_lines` presumably caps how many lines of each file are read.
mono_max_lines = 10000
mono0 = LazyIndexCorpus("./data_root/en_mono", max_lines=mono_max_lines)
# Language 1 is French. Previously this re-read "./data_root/en_mono", so the
# French word2vec stream trained on English text.
# NOTE(review): confirm the French monolingual file is named `fr_mono`,
# matching the en_multi / fr_multi naming.
mono1 = LazyIndexCorpus("./data_root/fr_mono", max_lines=mono_max_lines)

multi_max_lines = 10000
multi0 = LazyIndexCorpus("./data_root/en_multi", max_lines=multi_max_lines)
multi1 = LazyIndexCorpus("./data_root/fr_multi", max_lines=multi_max_lines)

In [64]:
# One unigram table per monolingual corpus, sized to the merged vocabulary.
# The tables are later handed to the iterators together with a
# negative-sample count.
mono0_unigram_table, mono1_unigram_table = (
    corpus.get_unigram_table(vocab_size=len(vocab))
    for corpus in (mono0, mono1)
)

In [67]:
# Hyper-parameters shared by both word2vec streams.
emb_subsample = 1e-5
word2vec_negative_size = 10
# Skip-gram context window. This used to be passed `word2vec_negative_size`
# (a copy-paste slip that tied two unrelated knobs together); the value 10 is
# unchanged, but it can now be tuned independently.
word2vec_window_size = 10
word2vec_batch_size = 100000


def _make_word2vec_iterator(corpus, unigram_table):
    """Build a Word2vecIterator over `corpus` using the shared word2vec settings."""
    return Word2vecIterator(
        corpus,
        unigram_table,
        subsample=emb_subsample,
        window_size=word2vec_window_size,
        negative_samples=word2vec_negative_size,
        batch_size=word2vec_batch_size,
    )


mono0_iterator = _make_word2vec_iterator(mono0, mono0_unigram_table)
mono1_iterator = _make_word2vec_iterator(mono1, mono1_unigram_table)

# Parallel-corpus iterator for the BilBOWA objective.
bilbowa_sent_length = 50
bilbowa_batch_size = 100
multi_iterator = BilbowaIterator(
    multi0,
    multi1,
    mono0_unigram_table,
    mono1_unigram_table,
    subsample=emb_subsample,
    length=bilbowa_sent_length,
    batch_size=bilbowa_batch_size,
)

In [70]:
# Batch iterators over the "strong" and "weak" word-pair id lists. Each gets
# both unigram tables, a negative-sample count, and the l0/l1 dicts built by
# pair2id.
strong_batch_size, strong_negative_size = 1000, 10
weak_batch_size, weak_negative_size = 3000, 10

strong_pair_iterator = strong_pairIterator(
    strong_id, mono0_unigram_table, mono1_unigram_table,
    l0_dict=l0_dict, l1_dict=l1_dict,
    negative_samples=strong_negative_size, batch_size=strong_batch_size,
)
weak_pair_iterator = weak_pairIterator(
    weak_id, mono0_unigram_table, mono1_unigram_table,
    l0_dict=l0_dict, l1_dict=l1_dict,
    negative_samples=weak_negative_size, batch_size=weak_batch_size,
)


In [73]:
# get_model returns eight models: the four used for training below, followed
# by their *_infer counterparts in the same order.
emb_dim = 50
encoder_desc_length = 15
(word2vec_model, bilbowa_model, strong_pair_model, weak_pair_model,
 word2vec_model_infer, bilbowa_model_infer,
 strong_pair_model_infer, weak_pair_model_infer) = get_model(
    nb_word=len(vocab),
    dim=emb_dim,
    word_emb_matrix=emb_matrix,
    context_emb_matrix=ctxemb_matrix,
    length=bilbowa_sent_length,
    desc_length=encoder_desc_length,
)


OUTPUT Tensor("dot_2/MatMul:0", shape=(?, 1, 1), dtype=float32)
OUTPUT Tensor("flatten_3/Reshape:0", shape=(?, ?), dtype=float32)
OUTPUT Tensor("multiply_1/mul:0", shape=(?, ?), dtype=float32)


In [None]:
# Print the architecture of every model the training loop trains.
# weak_pair_model was previously omitted even though it is trained too.
for _name, _model in (('word2vec_model', word2vec_model),
                      ('bilbowa_model', bilbowa_model),
                      ('strong_pair_model', strong_pair_model),
                      ('weak_pair_model', weak_pair_model)):
    logging.info('%s.summary()', _name)
    _model.summary()


In [None]:
def _make_adam(lr):
    """Return an Adam optimizer with AMSGrad enabled.

    Args:
        lr: learning rate. A negative value means "do not pass a learning
            rate", i.e. use Keras' default for Adam.
    """
    if lr < 0:
        return Adam(amsgrad=True)
    return Adam(lr=lr, amsgrad=True)


# Each model gets its own optimizer instance and its own loss.
# (The word2vec optimizer previously passed the misspelled keyword
# `amsgraword2vec_modeld`, which Adam rejects with a TypeError.)
word2vec_lr = 0.001
word2vec_model.compile(optimizer=_make_adam(word2vec_lr), loss=word2vec_loss)

bilbowa_lr = 0.001
bilbowa_model.compile(optimizer=_make_adam(bilbowa_lr), loss=bilbowa_loss)

strong_pair_model_lr = 0.001
strong_pair_model.compile(
    optimizer=_make_adam(strong_pair_model_lr), loss=strong_pair_loss)

weak_pair_model_lr = 0.001
weak_pair_model.compile(
    optimizer=_make_adam(weak_pair_model_lr), loss=weak_pair_loss)


In [None]:
# Batch generators, one per training task; the loop pulls from them with next().
mono0_iter, mono1_iter = mono0_iterator.fast2_iter(), mono1_iterator.fast2_iter()
multi_iter = multi_iterator.iter()
strong_iter = strong_pair_iterator.strong_iter()
weak_iter = weak_pair_iterator.weak_iter()

# Task names. When accumulated times tie, the loop picks the earliest name here.
keys = ('mono0', 'mono1', 'multi', 'strong_pair', 'weak_pair')

In [7]:
def dict_to_str(d):
    """Render `d` as '{k1: v1, k2: v2}', with keys in sorted order (for logging)."""
    parts = ('%s: %s' % (k, d[k]) for k in sorted(d))
    return '{%s}' % ', '.join(parts)

# Per-task bookkeeping for the training loop, keyed by task name.
comp_time = dict.fromkeys(keys, 0.0)     # seconds spent in train_on_batch
load_time = dict.fromkeys(keys, 0.0)     # seconds spent fetching batches
hit_count = dict.fromkeys(keys, 0)       # batches trained so far
iter_info = dict.fromkeys(keys, (0, 0))  # latest (epoch, instance) from the iterator
last_loss = dict.fromkeys(keys, 0.0)     # exponentially smoothed loss


def get_total_time():
    """Return {task: comp_time + load_time} for every task in `keys`."""
    return dict((key, comp_time[key] + load_time[key]) for key in keys)


global_start_time = time.time()
last_logging_time = 0.
last_saving_time = 0.
loss_decay = 0.6  # weight of the previous smoothed loss in the moving average



In [None]:
# Training-loop configuration. The original loop read these from absl
# `FLAGS`, which is never defined or parsed in this notebook, so the first
# iteration raised NameError.
# NOTE(review): the values are placeholders; confirm them against the
# training script's flag defaults.
max_mono_epochs = 1          # stop once both mono corpora reach this epoch; -1 disables
max_multi_epochs = -1        # stop once the parallel corpus reaches this epoch; -1 disables
logging_interval = 60.       # wall-clock seconds between progress logs
saving_interval = 3600.      # wall-clock seconds between checkpoints
model_root = './model_root'  # directory the models are saved into

# Dispatch table: task name -> (batch generator, model trained on its batches).
# Both monolingual streams train the same word2vec model.
tasks = {
    'mono0': (mono0_iter, word2vec_model),
    'mono1': (mono1_iter, word2vec_model),
    'multi': (multi_iter, bilbowa_model),
    'strong_pair': (strong_iter, strong_pair_model),
    'weak_pair': (weak_iter, weak_pair_model),
}

while True:
    # Time-balanced scheduling: train the task that has consumed the least
    # (load + compute) time so far. min() returns the first minimum, so ties
    # go to the earliest name in `keys`.
    next_key = min(keys, key=get_total_time().get)
    batch_iter, model = tasks[next_key]

    start_time = time.time()
    (x, y), (epoch, instance) = next(batch_iter)
    this_load_time = time.time() - start_time

    start_time = time.time()
    loss = model.train_on_batch(x=x, y=y)
    this_comp_time = time.time() - start_time

    comp_time[next_key] += this_comp_time
    load_time[next_key] += this_load_time
    hit_count[next_key] += 1
    iter_info[next_key] = (epoch, instance)
    # Exponential moving average of the loss, seeded with the first value.
    # Seeding keys on hit_count (not `last_loss == 0.0`), so a genuine loss of
    # exactly 0.0 does not restart the average.
    if hit_count[next_key] == 1:
        last_loss[next_key] = loss
    else:
        last_loss[next_key] = (last_loss[next_key] * loss_decay
                               + loss * (1. - loss_decay))

    # Exit once any enabled epoch target is reached.
    mono_done = (max_mono_epochs > -1
                 and iter_info['mono0'][0] >= max_mono_epochs
                 and iter_info['mono1'][0] >= max_mono_epochs)
    multi_done = (max_multi_epochs > -1
                  and iter_info['multi'][0] >= max_multi_epochs)
    should_exit = bool(mono_done or multi_done)

    elapsed = time.time() - global_start_time
    if should_exit or elapsed - last_logging_time > logging_interval:
        last_logging_time = elapsed
        logging.info('hit_count = %s', dict_to_str(hit_count))
        logging.info('last_loss = %s', dict_to_str(last_loss))
        # NOTE(review): this only looks the attribute up and never calls it,
        # so no evaluation actually runs. Call Evaluator.word_translation
        # with its expected arguments once they are confirmed.
        evaluator.word_translation

    # Periodic checkpoint, plus a final one on exit.
    if should_exit or elapsed - last_saving_time > saving_interval:
        last_saving_time = elapsed
        logging.info('Saving models started.')
        os.makedirs(model_root, exist_ok=True)  # save() fails if the directory is missing
        tag = ''
        for save_name, to_save in (('word2vec_model', word2vec_model),
                                   ('bilbowa_model', bilbowa_model),
                                   ('word2vec_model_infer', word2vec_model_infer),
                                   ('bilbowa_model_infer', bilbowa_model_infer)):
            to_save.save(join(model_root, tag + save_name))
        logging.info('Saving models done.')

    if should_exit:
        logging.info('Training target reached. Exit.')
        break
