# 作業 : 調整 ELMo 模型的不同訓練參數

# [作業目標]
- 調整 ELMo 模型的不同參數, 分別觀察 loss 數據並比較

# [作業重點]
- 調整 ELMo 模型的不同訓練參數, 分別觀察 loss 數據並比較

# [參數說明]
- UNITS : LSTM 隱藏狀態的特徵維度（本程式中同時作為詞嵌入維度 emb_dim），並非輸入單字的個數
- N_LAYERS : LSTM 堆疊的層數
- BATCH_SIZE : 訓練批次大小，即幾筆資料合併做一次倒傳遞
- LEARNING_RATE : 學習速率，影響收斂的快慢，須配合 BATCH_SIZE 調整

# 程式說明
- 程式採用 tensorflow2 / keras 寫作, 執行前請先安裝 tensorflow 2.0
- 本程式執行時, 請將 utils.py 與執行檔放置於同一目錄下
- 程式來源 : 莫煩Python-ELMo:一詞多義 https://mofanpy.com/tutorials/machine-learning/nlp/elmo/

In [1]:
# [Deep contextualized word representations](https://arxiv.org/pdf/1802.05365.pdf)
from tensorflow import keras
import tensorflow as tf
import utils    # this refers to utils.py in my [repo](https://github.com/MorvanZhou/NLP-Tutorials/)
import time
import os

class ELMo(keras.Model):
    """Bidirectional stacked-LSTM language model in the style of ELMo.

    A shared word embedding feeds two independent LSTM stacks: a forward
    stack trained to predict the next token and a backward stack trained to
    predict the previous token. The per-layer outputs of both stacks can be
    concatenated to form contextual embeddings (see ``get_emb``).

    Args:
        v_dim: Vocabulary size; used as the embedding input dimension and as
            the width of both output projection layers.
        emb_dim: Word embedding dimension.
        units: Hidden size of every LSTM layer.
        n_layers: Number of stacked LSTM layers per direction.
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, v_dim, emb_dim, units, n_layers, lr):
        super().__init__()
        self.n_layers = n_layers
        self.units = units

        # encoder
        # mask_zero=True makes token id 0 act as padding for the mask
        # computed in call().
        self.word_embed = keras.layers.Embedding(
            input_dim=v_dim, output_dim=emb_dim,  # [n_vocab, emb_dim]
            embeddings_initializer=keras.initializers.RandomNormal(0., 0.001),
            mask_zero=True,
        )
        # forward lstm
        self.fs = [keras.layers.LSTM(units, return_sequences=True) for _ in range(n_layers)]
        self.f_logits = keras.layers.Dense(v_dim)
        # backward lstm
        # go_backwards=True: the layer consumes the sequence in reverse and
        # emits its outputs in that reversed order (undone in call()).
        self.bs = [keras.layers.LSTM(units, return_sequences=True, go_backwards=True) for _ in range(n_layers)]
        self.b_logits = keras.layers.Dense(v_dim)

        # One loss object per direction; both take raw logits.
        self.cross_entropy1 = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.cross_entropy2 = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(lr)

    def call(self, seqs):
        """Run the embedding and both LSTM stacks over a batch of sequences.

        Args:
            seqs: Batch of token ids, indexed as [n, step] (it is sliced on
                axis 1 below). Presumably 0 is the padding id, matching
                ``mask_zero=True`` on the embedding -- TODO confirm against
                the data loader.

        Returns:
            A tuple ``(fxs, bxs)`` of lists with ``n_layers + 1`` entries
            each. Entry 0 is the sliced word embedding; entry ``i`` is the
            output of the i-th LSTM layer of that direction. Backward outputs
            are reversed along the time axis before being stored.
        """
        embedded = self.word_embed(seqs)        # [n, step, dim]
        """
        0123    forward
        1234    forward predict
         1234   backward 
         0123   backward predict
        """
        mask = self.word_embed.compute_mask(seqs)
        # Forward stack sees every token but the last; backward stack sees
        # every token but the first.
        fxs, bxs = [embedded[:, :-1]], [embedded[:, 1:]]
        for fl, bl in zip(self.fs, self.bs):
            # Each layer consumes the previous layer's output of its own
            # direction and starts from the layer's default initial state.
            fx = fl(
                fxs[-1], mask=mask[:, :-1], initial_state=fl.get_initial_state(fxs[-1])
            )           # [n, step-1, dim]
            bx = bl(
                bxs[-1], mask=mask[:, 1:], initial_state=bl.get_initial_state(bxs[-1])
            )  # [n, step-1, dim]
            fxs.append(fx)      # predict 1234
            # Flip the time axis back so the next backward layer and the
            # loss see the outputs in original token order.
            bxs.append(tf.reverse(bx, axis=[1]))    # predict 0123
        return fxs, bxs

    def step(self, seqs):
        """Perform one training step on a batch and update the weights.

        The top-layer outputs of each direction are projected to vocabulary
        logits. The forward logits are scored against ``seqs[:, 1:]`` and the
        backward logits against ``seqs[:, :-1]``; the two losses are averaged.

        Args:
            seqs: Batch of token ids, same layout as for ``call``.

        Returns:
            ``(loss, (fo, bo))`` where ``loss`` is the averaged cross-entropy
            and ``fo`` / ``bo`` are the forward / backward logits.
        """
        with tf.GradientTape() as tape:
            fxs, bxs = self.call(seqs)
            fo, bo = self.f_logits(fxs[-1]), self.b_logits(bxs[-1])
            loss = (self.cross_entropy1(seqs[:, 1:], fo) + self.cross_entropy2(seqs[:, :-1], bo))/2
        grads = tape.gradient(loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return loss, (fo, bo)

    def get_emb(self, seqs):
        """Return per-layer contextual embeddings as NumPy arrays.

        For every entry of ``call``'s output (embedding plus each LSTM
        layer), the forward output without its last step and the backward
        output without its first step are concatenated on the feature axis.
        The shape of each resulting array is printed.

        Args:
            seqs: Batch of token ids, same layout as for ``call``.

        Returns:
            A list of ``n_layers + 1`` NumPy arrays.
        """
        fxs, bxs = self.call(seqs)
        # Dropping the last forward step and the first backward step leaves
        # both tensors with the same number of time steps for the concat.
        xs = [tf.concat((f[:, :-1, :], b[:, 1:, :]), axis=2).numpy() for f, b in zip(fxs, bxs)]
        for x in xs:
            print("layers shape=", x.shape)
        return xs


def train(model, data, step, batch_size=None, ckpt_path="./visual/models/elmo/model.ckpt"):
    """Train ``model`` for ``step`` iterations and save its weights.

    Every 80 iterations a progress report is printed: the step index, the
    wall-clock time since the previous report, the current loss, and the
    first sequence of the batch next to the forward / backward predictions
    for it (padding ids are dropped from all three).

    Args:
        model: Object exposing ``step(seqs) -> (loss, (fo, bo))`` and
            ``save_weights(path)``.
        data: Object exposing ``sample(n)``, ``i2v`` (id -> token) and
            ``pad_id``.
        step: Number of training iterations to run.
        batch_size: Number of sequences sampled per iteration. When ``None``
            the module-level ``BATCH_SIZE`` is used, which keeps existing
            ``train(model, data, step)`` callers working.
        ckpt_path: Where the weights are written once training finishes.
            Missing parent directories are created.

    Raises:
        NameError: If ``batch_size`` is ``None`` and no module-level
            ``BATCH_SIZE`` has been defined.
    """
    if batch_size is None:
        # Legacy behaviour: the script entry point defines BATCH_SIZE as a
        # global. Passing batch_size explicitly removes that hidden coupling.
        batch_size = BATCH_SIZE

    def decode(ids):
        # Map ids to tokens, skipping padding positions.
        return " ".join([data.i2v[i] for i in ids if i != data.pad_id])

    t0 = time.time()
    for t in range(step):
        seqs = data.sample(batch_size)
        loss, (fo, bo) = model.step(seqs)
        if t % 80 == 0:
            # Greedy prediction for the first sequence of the batch only.
            fp = fo[0].numpy().argmax(axis=1)
            bp = bo[0].numpy().argmax(axis=1)
            t1 = time.time()
            print(
                "\n\nstep: ", t,
                "| time: %.2f" % (t1 - t0),
                "| loss: %.3f" % loss.numpy(),
                "\n| tgt: ", decode(seqs[0]),
                "\n| f_prd: ", decode(fp),
                "\n| b_prd: ", decode(bp),
                )
            # Restart the timer so each report covers one logging interval.
            t0 = t1

    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        # A bare filename has no directory part; os.makedirs("") would raise.
        os.makedirs(ckpt_dir, exist_ok=True)
    model.save_weights(ckpt_path)


def export_w2v(model, data):
    """Reload the saved checkpoint and print embeddings for a small sample.

    Args:
        model: Object exposing ``load_weights(path)`` and ``get_emb(seqs)``.
        data: Object exposing ``sample(n)``.
    """
    checkpoint = "./visual/models/elmo/model.ckpt"
    model.load_weights(checkpoint)
    # Four sampled sequences are enough to inspect the output by eye.
    batch = data.sample(4)
    layer_embeddings = model.get_emb(batch)
    print(layer_embeddings)

if __name__ == "__main__":
    # Hyper-parameters to tune for this exercise. They stay module-level
    # names because train() looks BATCH_SIZE up as a global.
    UNITS = 256             # LSTM hidden size, also used as embedding size
    N_LAYERS = 2            # stacked LSTM layers per direction
    BATCH_SIZE = 16         # sequences sampled per training step
    LEARNING_RATE = 2e-3    # Adam learning rate

    utils.set_soft_gpu(True)

    # Load the first 2000 rows of the MRPC corpus from ./MRPC.
    d = utils.MRPCSingle("./MRPC", rows=2000)
    print("num word: ", d.num_word)

    m = ELMo(
        v_dim=d.num_word,
        emb_dim=UNITS,
        units=UNITS,
        n_layers=N_LAYERS,
        lr=LEARNING_RATE,
    )
    train(m, d, 10000)
    export_w2v(m, d)

num word:  12880


step:  0 | time: 2.07 | loss: 9.463 
| tgt:  <GO> police arrested a <quote> potential suspect <quote> monday in the abduction of a <NUM>-year-old who was found safe after two days , the police chief said . <SEP> 
| f_prd:  acer acer mobility vern vern vern vern obtain outnumber outnumber outnumber outnumber hurting hurting trunk malice trunk atadero atadero 1980s expectancy expectancy expectancy idc idc supervisor supervisor supervisor supervisor supervisor supervisor supervisor supervisor supervisor supervisor supervisor supervisor 
| b_prd:  canon hornets canon anti-trust canon history spans transformed battalion battalion .... .... .... turquoise tool turquoise equal-weight gigabyte gigabyte winkenwerder winkenwerder winkenwerder winkenwerder banning lapel trade fraudulently rival


step:  80 | time: 133.36 | loss: 6.524 
| tgt:  <GO> <quote> if it ain 't broke , don 't fix it , <quote> said senate minority leader tom daschle , a south dakota democrat . <SEP> 
| f



step:  1520 | time: 119.69 | loss: 4.455 
| tgt:  <GO> commenting on the firing today , ms. novikova said that there was no standard weight for ballerinas but that ms. volochkova <quote> is bigger than others . <quote> <SEP> 
| f_prd:  the , the <NUM> <NUM> , the <NUM> , , the , not <NUM> <NUM> , the , the the <NUM> of the the <quote> <quote> said <SEP> <SEP> 
| b_prd:  <GO> <GO> <GO> the <GO> <GO> <GO> the <GO> <GO> <GO> <GO> <GO> the <GO> <GO> the , , , the , , the , the <NUM> . .


step:  1600 | time: 118.65 | loss: 4.767 
| tgt:  <GO> the report was found last week tucked inside a training manual that belonged to hicks . <SEP> 
| f_prd:  the dow of a the <NUM> to a the <NUM> , <NUM> <NUM> , the . <SEP> 
| b_prd:  <GO> the , , , the , , , the the , the <NUM> the <NUM> .


step:  1680 | time: 118.12 | loss: 4.702 
| tgt:  <GO> michael schiavo has argued that his wife never wanted to be kept alive artificially . <SEP> 
| f_prd:  the <NUM> , been the the <NUM> of been to the the to t



step:  3200 | time: 118.30 | loss: 3.043 
| tgt:  <GO> kodak expects earnings of <NUM> cents to <NUM> cents a share in the quarter . <SEP> 
| f_prd:  the , writer of the cents to $ cents a share . the new . <SEP> 
| b_prd:  <GO> , the <NUM> $ <NUM> , $ <NUM> , a , in a <NUM> .


step:  3280 | time: 119.15 | loss: 3.164 
| tgt:  <GO> bob hope , master of the one-liner and americas favourite comedian , died with a smile on his face yesterday just months after celebrating his 100th birthday . <SEP> 
| f_prd:  the richter , the 's the one-liner of <NUM> management comedian , who from the smile , the wife in , the from the the 100th birthday . <SEP> 
| b_prd:  <GO> <GO> <NUM> a <NUM> of the , a and be <NUM> , , of to <NUM> of would , and <NUM> , , in a stock <NUM> .


step:  3360 | time: 120.14 | loss: 2.896 
| tgt:  <GO> mr berlusconi is accused of bribing judges to influence a takeover battle in the 1980s involving sme , a state-owned food company . <SEP> 
| f_prd:  the schiavo was the 



step:  4800 | time: 124.74 | loss: 2.414 
| tgt:  <GO> security forces stormed the building , but <NUM> hostages were killed along with the attackers . <SEP> 
| f_prd:  the press jurors the building after including the million in killed in with secretly <NUM> . <SEP> 
| b_prd:  nasdaq 's he of the <NUM> , the the , been , , on the said .


step:  4880 | time: 124.74 | loss: 2.355 
| tgt:  <GO> mosel had been unable to present financial statements in time because of mergers among subsidiaries and a change of accountants , the company said yesterday . <SEP> 
| f_prd:  the was been unable to be de statements , velcade for his mergers a subsidiaries and the change network accountants , the company 's . . <SEP> 
| b_prd:  <GO> it have been expected to 's the <quote> the . part to was for , , the part in <NUM> , the <quote> year said .


step:  4960 | time: 126.41 | loss: 2.322 
| tgt:  <GO> the eu 's agriculture representative in washington said eu ministers were invited but canceled beca



step:  6400 | time: 153.12 | loss: 1.978 
| tgt:  <GO> an average residential customer paying $ <NUM> a year for electricity could see a savings of $ <NUM> annually . <SEP> 
| f_prd:  the artist residential customer paying the <NUM> million year and electricity to be <NUM> savings of <NUM> <NUM> million . <SEP> 
| b_prd:  <GO> the index a to to to , a was on , to is a percent to $ <NUM> said .


step:  6480 | time: 151.14 | loss: 1.836 
| tgt:  <GO> gerry kiely , a eu agriculture representative in washington , said eu ministers were invited but canceled because the union is closing talks on agricultural reform . <SEP> 
| f_prd:  the kiely , a eu agriculture representative in washington and said eu ministers were invited but canceled because the union is closing talks on agricultural reform . <SEP> 
| b_prd:  <GO> gerry <NUM> <GO> the 's agriculture <GO> in <NUM> he <GO> eu , were , , canceled said the , was on <NUM> on agricultural said .


step:  6560 | time: 155.00 | loss: 1.636 
|



step:  8000 | time: 145.23 | loss: 1.508 
| tgt:  <GO> when i talked to him last time , did i think it was the end-all ? <SEP> 
| f_prd:  the the think to the , week , he not think our was a end-all ? <SEP> 
| b_prd:  <GO> <quote> respectfully referring to said the ex-girlfriend , , i said he be the said .


step:  8080 | time: 144.26 | loss: 1.400 
| tgt:  <GO> i stand <NUM> percent by it , and i think that our intelligence services gave us the correct intelligence and information at the time , <quote> blair said . <SEP> 
| f_prd:  the would the percent of the 's and i think our intelligence intelligence services gave the the correct intelligence services information at <quote> time . <quote> blair said . <SEP> 
| b_prd:  <quote> we about <NUM> weight pushed . , <quote> i rejection is 's 's i to on the material shuddered update aimed in the said , <quote> he said .


step:  8160 | time: 145.77 | loss: 1.393 
| tgt:  <GO> <quote> it 's a new beginning as far as the courts , crown pro



step:  9520 | time: 128.12 | loss: 1.107 
| tgt:  <GO> mel gibson is negotiating with newmarket films to distribute his embattled biblical epic <quote> the passion of christ <quote> in the united states . <SEP> 
| f_prd:  the gibson 's negotiating with newmarket films to distribute the embattled biblical epic <quote> the passion of christ united are the united states , <SEP> 
| b_prd:  <GO> jean-marie it is veteran with disc continue to with his embattled biblical , of the behalf of satisfaction strangled in the united said .


step:  9600 | time: 133.24 | loss: 0.947 
| tgt:  <GO> brent crude for july delivery fell <NUM> cents to $ <NUM> a barrel on london 's international petroleum exchange . <SEP> 
| f_prd:  the crude for the delivery fell <NUM> cents to $ <NUM> a barrel on the 's international petroleum exchange . <SEP> 
| b_prd:  <GO> <GO> shares <GO> the shares rose <NUM> cents at $ as per barrel on google 's 's stock said .


step:  9680 | time: 140.54 | loss: 1.320 
| tgt:  <