# 作業 : 調整 ELMo 模型的不同訓練參數

# [作業目標]
- 調整 ELMo 模型的不同參數, 分別觀察 loss 數據並比較

# [作業重點]
- 調整 ELMo 模型的不同訓練參數, 分別觀察 loss 數據並比較

# [參數說明]
- UNITS : LSTM 隱藏狀態 (hidden state) 的特徵維度；本程式同時以它作為詞嵌入 (word embedding) 的維度 emb_dim，並非每一筆輸入的單字個數
- N_LAYERS : LSTM 堆疊的層數
- BATCH_SIZE : 訓練批次大小，即幾筆資料合併做一次倒傳遞
- LEARNING_RATE : 學習速率，影響收斂的快慢，須配合 BATCH_SIZE 調整

# 程式說明
- 程式採用 TensorFlow 2 / Keras 撰寫, 執行前請先安裝 TensorFlow 2.0
- 本程式執行時, 請將 utils.py 與執行檔放置於同一目錄下
- 程式來源 : 莫煩Python-ELMo:一詞多義 https://mofanpy.com/tutorials/machine-learning/nlp/elmo/

In [1]:
# [Deep contextualized word representations](https://arxiv.org/pdf/1802.05365.pdf)
from tensorflow import keras
import tensorflow as tf
import utils    # this refers to utils.py in my [repo](https://github.com/MorvanZhou/NLP-Tutorials/)
import time
import os

class ELMo(keras.Model):
    """ELMo-style bidirectional language model built from two stacked-LSTM towers.

    The forward tower reads tokens 0..T-2 and predicts tokens 1..T-1; the
    backward tower reads tokens 1..T-1 right-to-left and predicts tokens
    0..T-2. Every layer's output is kept so `get_emb` can expose the
    representation of each layer.

    Args:
        v_dim: vocabulary size (embedding rows and number of output logits).
        emb_dim: word-embedding dimension.
        units: hidden size of every LSTM layer.
        n_layers: number of stacked LSTM layers per direction.
        lr: Adam learning rate.
    """

    def __init__(self, v_dim, emb_dim, units, n_layers, lr):
        super().__init__()
        self.n_layers = n_layers
        self.units = units

        # encoder; mask_zero=True makes token id 0 the padding id that is masked out
        self.word_embed = keras.layers.Embedding(
            input_dim=v_dim, output_dim=emb_dim,  # [n_vocab, emb_dim]
            embeddings_initializer=keras.initializers.RandomNormal(0., 0.001),
            mask_zero=True,
        )
        # forward lstm
        self.fs = [keras.layers.LSTM(units, return_sequences=True) for _ in range(n_layers)]
        self.f_logits = keras.layers.Dense(v_dim)
        # backward lstm: go_backwards=True consumes the sequence right-to-left and
        # returns its outputs in reversed time order (flipped back in `call`).
        self.bs = [keras.layers.LSTM(units, return_sequences=True, go_backwards=True) for _ in range(n_layers)]
        self.b_logits = keras.layers.Dense(v_dim)

        self.cross_entropy1 = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.cross_entropy2 = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(lr)

    def call(self, seqs):
        """Run both LSTM towers over a padded batch of token ids.

        Alignment of inputs and targets:
            0123    forward input
            1234    forward predict
             1234   backward input
             0123   backward predict

        Args:
            seqs: integer token ids, presumably shaped [n, step] (TODO confirm
                against the data loader).

        Returns:
            (fxs, bxs): two lists of length n_layers + 1. Index 0 holds the
            shifted embeddings; index i >= 1 holds the output of LSTM layer i.
            Backward outputs are flipped back into left-to-right order.
        """
        embedded = self.word_embed(seqs)        # [n, step, dim]
        mask = self.word_embed.compute_mask(seqs)
        f_mask, b_mask = mask[:, :-1], mask[:, 1:]
        fxs, bxs = [embedded[:, :-1]], [embedded[:, 1:]]
        for fl, bl in zip(self.fs, self.bs):
            # No explicit initial_state: Keras RNNs start from an all-zero state by
            # default, which is exactly what get_initial_state() returned. Passing
            # get_initial_state(inputs) also breaks on Keras 3, where the method
            # takes a batch size instead of an input tensor.
            fx = fl(fxs[-1], mask=f_mask)       # [n, step-1, dim]
            bx = bl(bxs[-1], mask=b_mask)       # [n, step-1, dim], reversed in time
            fxs.append(fx)                          # predict 1234
            bxs.append(tf.reverse(bx, axis=[1]))    # predict 0123
        return fxs, bxs

    def step(self, seqs):
        """Run one Adam update on a batch.

        Loss is the mean of the forward cross-entropy (targets seqs[:, 1:]) and
        the backward cross-entropy (targets seqs[:, :-1]).
        NOTE(review): padded target positions are only excluded if Keras mask
        propagation reaches the loss; the backward logits come from tf.reverse
        output, so this may differ between directions — confirm.

        Returns:
            (loss, (fo, bo)) where fo / bo are the forward / backward logits of
            the top layer.
        """
        with tf.GradientTape() as tape:
            fxs, bxs = self.call(seqs)
            fo, bo = self.f_logits(fxs[-1]), self.b_logits(bxs[-1])
            loss = (self.cross_entropy1(seqs[:, 1:], fo) + self.cross_entropy2(seqs[:, :-1], bo))/2
        grads = tape.gradient(loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return loss, (fo, bo)

    def get_emb(self, seqs):
        """Return per-layer contextual features as numpy arrays (and print their shapes).

        For every layer, forward position t is concatenated (last axis) with
        backward position t+1. By the alignment in `call`, both halves then
        describe token t+1 from its left and right context respectively.
        The result is a list of n_layers + 1 arrays.
        """
        fxs, bxs = self.call(seqs)
        xs = [tf.concat((f[:, :-1, :], b[:, 1:, :]), axis=2).numpy() for f, b in zip(fxs, bxs)]
        for x in xs:
            print("layers shape=", x.shape)
        return xs


def train(model, data, step, batch_size=None, log_every=80,
          save_dir="./visual/models/elmo"):
    """Train `model` for `step` mini-batches, log progress, then save its weights.

    Args:
        model: object exposing `step(seqs) -> (loss, (fo, bo))` and
            `save_weights(path)` (e.g. ELMo).
        data: dataset exposing `sample(n)`, `i2v` (id -> word) and `pad_id`.
        step: number of optimisation steps to run.
        batch_size: sequences per step. When None, falls back to the
            module-level BATCH_SIZE set in the `__main__` cell (the original
            behaviour).
        log_every: print loss and sample predictions every this many steps.
        save_dir: directory that receives `model.ckpt`.

    Returns:
        list of (step_index, loss) tuples, one per logged step, so runs with
        different hyper-parameters can be compared programmatically.

    Raises:
        ValueError: if `log_every` is not positive.
    """
    if log_every <= 0:
        raise ValueError("log_every must be a positive integer")
    if batch_size is None:
        batch_size = BATCH_SIZE  # legacy: read the notebook-level global

    def to_text(ids):
        # Map token ids to words, dropping padding.
        return " ".join(data.i2v[i] for i in ids if i != data.pad_id)

    history = []
    t0 = time.time()
    for t in range(step):
        seqs = data.sample(batch_size)
        loss, (fo, bo) = model.step(seqs)
        if t % log_every == 0:
            loss_value = float(loss.numpy())
            history.append((t, loss_value))
            fp = fo[0].numpy().argmax(axis=1)
            bp = bo[0].numpy().argmax(axis=1)
            t1 = time.time()
            print(
                "\n\nstep: ", t,
                "| time: %.2f" % (t1 - t0),
                "| loss: %.3f" % loss_value,
                "\n| tgt: ", to_text(seqs[0]),
                "\n| f_prd: ", to_text(fp),
                "\n| b_prd: ", to_text(bp),
                )
            t0 = t1
    os.makedirs(save_dir, exist_ok=True)
    model.save_weights(os.path.join(save_dir, "model.ckpt"))
    return history


def export_w2v(model, data, ckpt_path="./visual/models/elmo/model.ckpt", n_samples=4):
    """Reload saved weights, then print and return contextual embeddings.

    Args:
        model: object exposing `load_weights(path)` and `get_emb(seqs)`.
        data: dataset exposing `sample(n)`.
        ckpt_path: checkpoint to restore (defaults to where `train` saves).
        n_samples: number of sequences to sample and embed.

    Returns:
        whatever `model.get_emb` returns (for ELMo, a list of per-layer numpy
        arrays). It was previously only printed and then discarded.
    """
    model.load_weights(ckpt_path)
    emb = model.get_emb(data.sample(n_samples))
    print(emb)
    return emb


if __name__ == "__main__":
    utils.set_soft_gpu(True)
    # Experiment 1 hyper-parameters. They must remain module-level globals:
    # train() reads BATCH_SIZE from module scope.
    UNITS, N_LAYERS = 256, 2
    BATCH_SIZE, LEARNING_RATE = 16, 2e-3

    # rows=2000 presumably caps how many MRPC rows are loaded — confirm in utils.py
    d = utils.MRPCSingle("./MRPC", rows=2000)
    print("num word: ", d.num_word)
    m = ELMo(
        d.num_word,
        emb_dim=UNITS,
        units=UNITS,
        n_layers=N_LAYERS,
        lr=LEARNING_RATE,
    )
    train(m, d, step=10000)   # 10k optimisation steps, then weights are saved
    export_w2v(m, d)          # reload the saved weights and print embeddings

num word:  12880


step:  0 | time: 5.31 | loss: 9.463 
| tgt:  <GO> <quote> i 'm taking his office , and we 're gonna keep on building , ' ' he vowed . <quote> <SEP> 
| f_prd:  nursery occurring occurring backlash putnam putnam sec sec sec incidence incidence clinton clinton sunnyvale lopez lopez nancy smeared smeared smeared smeared fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled fuelled 
| b_prd:  velasco velasco velasco harmless mentally outlook highlighting highlighting highlighting highlighting harmless laotian laotian laotian laotian dallager arguments sig transit rohr pie epicentre


step:  80 | time: 85.50 | loss: 6.390 
| tgt:  <GO> investigators were searching his home in muenster in the presence of his wife when news of his death came , prosecutor wolfgang schweer said . <SEP> 
| f_prd:   
| b_prd:  the the the the the the the the the the the the the the the the the the the the the the the the the the <



step:  1760 | time: 93.61 | loss: 4.411 
| tgt:  <GO> chief financial officer andy bryant has said that hike had a greater affect volume than officials expected . <SEP> 
| f_prd:  the said industrial , the , not <NUM> <NUM> , not <quote> of <NUM> of the . . <SEP> 
| b_prd:  <GO> <GO> <GO> <GO> <GO> , , , the , , the the , , , the <NUM> .


step:  1840 | time: 97.22 | loss: 4.207 
| tgt:  <GO> the swiss franc rose three quarters of a percent against the dollar and was last at <NUM> to the greenback . <SEP> 
| f_prd:  the dow jones of <NUM> <NUM> , the first , the new , be not than the , the end . <SEP> 
| b_prd:  <GO> <GO> <GO> the the the , , <NUM> <NUM> , the , , , , the , , the <NUM> .


step:  1920 | time: 88.59 | loss: 3.951 
| tgt:  <GO> a lawsuit has been filed in an attempt to block the removal of the ten commandments monument from the building . <SEP> 
| f_prd:  the dow , been been the the first , the , first of the court . not . the court . <SEP> 
| b_prd:  <GO> <GO> <GO> th



step:  3600 | time: 92.11 | loss: 3.138 
| tgt:  <GO> jacob has pushed consolidation for years , but he has said many communities , especially rural ones , have opposed it . <SEP> 
| f_prd:  the said been consolidation in the , but he said not it communities , which the ones . said not the . <SEP> 
| b_prd:  <GO> he has we <GO> <NUM> said <GO> said , , of two <NUM> , of the <NUM> to , said <NUM> .


step:  3680 | time: 89.63 | loss: 2.645 
| tgt:  <GO> he claims it may seem unrealistic only because little effort has been devoted to the concept . <SEP> 
| f_prd:  the said the would have to minds to the gamble to been devoted to the concept . <SEP> 
| b_prd:  <GO> he <GO> at we , , said the the , have be , in the said .


step:  3760 | time: 89.09 | loss: 2.552 
| tgt:  <GO> it was predicted to become a category i hurricane overnight . <SEP> 
| f_prd:  the has a the charing the category of hospital proceedings penalty <SEP> 
| b_prd:  <GO> it was , not in a of the on said .


step:  38



step:  5040 | time: 85.00 | loss: 2.111 
| tgt:  <GO> the talk , however , has been downplayed by pbl which said it would focus only on smaller purchases that were immediately earnings and cash flow-accretive . <SEP> 
| f_prd:  the dow was the , the not arrested by pbl <quote> said the would be the to the purchases . the immediately immediately . cash flow-accretive . <SEP> 
| b_prd:  <GO> the however , . it has was downplayed of , , said it will for <NUM> of the purchases , and with fair and <quote> <NUM> .


step:  5120 | time: 86.17 | loss: 2.242 
| tgt:  <GO> lord of the rings director peter jackson and longtime companion fran walsh . <SEP> 
| f_prd:  the of the oregon of of most was longtime companion fran walsh . <SEP> 
| b_prd:  <GO> head <GO> the athletic j.d. , , and and said charles <NUM> .


step:  5200 | time: 84.95 | loss: 1.784 
| tgt:  <GO> the mta argued it needed to raise fares to close a two-year deficit it estimated , at different times , to be $ <NUM> million or $



step:  6560 | time: 82.42 | loss: 1.330 
| tgt:  <GO> ms stewart , the chief executive , was not expected to attend . <SEP> 
| f_prd:  the stewart , the chief executive , said not bad to attend . <SEP> 
| b_prd:  <GO> robert mitchell , and chief villepin , was is declined to said .


step:  6640 | time: 80.12 | loss: 1.361 
| tgt:  <GO> <quote> we have found the smoking gun , <quote> investigating board member scott hubbard said . <SEP> 
| f_prd:  the i have a whether most gun and <quote> said spencer nor spencer spencer said . <SEP> 
| b_prd:  <GO> <quote> we probably pass the a preparedness , , the an <quote> <quote> he said .


step:  6720 | time: 79.78 | loss: 1.445 
| tgt:  <GO> there was no way the man could hear him , but he turned and mouthed something . <SEP> 
| f_prd:  the was no immediate the most on hear him , but he turned in mouthed something . <SEP> 
| b_prd:  <GO> there <GO> the said the it could hear stabilization , that he typhoid and mouthed said .


step:  6800 | 



step:  8080 | time: 78.98 | loss: 1.113 
| tgt:  <GO> in millville yesterday , mayor james quinn ordered all city flags flown at half-staff for the next <NUM> days . <SEP> 
| f_prd:  the <NUM> trading , mayor james quinn ordered the of flags flown at half-staff for the next two days . <SEP> 
| b_prd:  <GO> in millville unsung , mayor james quinn by all city flags flown in half-staff in the past <NUM> <NUM> .


step:  8160 | time: 79.11 | loss: 1.202 
| tgt:  <GO> an hour later israeli attack helicopters rained missiles on a car in gaza city , killing seven people , palestinian sources said . <SEP> 
| f_prd:  the hour , israeli judge helicopters rained missiles on the car as gaza city , killing seven people , palestinian sources said . <SEP> 
| b_prd:  <GO> an hour an israeli to helicopters rained missiles in a kept in gaza africa , including <NUM> baghdad , , he <NUM> .


step:  8240 | time: 79.37 | loss: 1.270 
| tgt:  <GO> peterson , <NUM> , was arrested in la jolla april <NUM> aft



step:  9520 | time: 79.74 | loss: 0.887 
| tgt:  <GO> the two democrats on the five-member fcc panel held a news conference to sway opinion against powell . <SEP> 
| f_prd:  the dow forms on the five-member fcc panel held file question conference to sway opinion against powell . <SEP> 
| b_prd:  <GO> the two democrats on the five-member fcc also from a news due to sway opinion against said .


step:  9600 | time: 80.11 | loss: 0.810 
| tgt:  <GO> mosel was unable to present financial statements in time due to mergers among subsidiaries and a change of accountants , the company said on monday . <SEP> 
| f_prd:  the was unable to turn financial statements in a by to mergers among subsidiaries and change change of accountants , the company said monday monday . <SEP> 
| b_prd:  <GO> mosel were unable to present financial statements to , due to mergers among subsidiaries and a change of accountants , the company <NUM> on said .


step:  9680 | time: 79.30 | loss: 0.845 
| tgt:  <GO> oracl

In [2]:
if __name__ == "__main__":
    utils.set_soft_gpu(True)
    # Experiment 2: same model size as experiment 1, but batch 32 and lr 1e-3.
    # These stay module-level globals because train() reads BATCH_SIZE.
    UNITS, N_LAYERS = 256, 2
    BATCH_SIZE, LEARNING_RATE = 32, 1e-3

    # rows=2000 presumably caps how many MRPC rows are loaded — confirm in utils.py
    d = utils.MRPCSingle("./MRPC", rows=2000)
    print("num word: ", d.num_word)
    m = ELMo(
        d.num_word,
        emb_dim=UNITS,
        units=UNITS,
        n_layers=N_LAYERS,
        lr=LEARNING_RATE,
    )
    train(m, d, step=10000)   # same step budget as experiment 1
    export_w2v(m, d)          # reload the saved weights and print embeddings

num word:  12880


step:  0 | time: 2.00 | loss: 9.463 
| tgt:  <GO> the current plan is to release videotapes of the sessions on friday , after the review , said jim landale , a tribunal spokesman . <SEP> 
| f_prd:  trapped until until until until processor session session greatest moroccan moroccan guards guards guards guards xbox xbox xbox mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly mainly 
| b_prd:  q into into into blowtorch blowtorch door door entrance entrance entrance entrance entrance entrance revolutions chorale stand pro-kremlin lax reassigned alert : littleton littleton capps capps


step:  80 | time: 112.08 | loss: 6.584 
| tgt:  <GO> <quote> we continue to work with the saudis on this , but they did not , as of the time of this tragic event , provide the additional security we requested . <quote> <SEP> 
| f_prd:   
| b_prd:  , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , .



step:  1520 | time: 112.42 | loss: 4.757 
| tgt:  <GO> at <NUM> a.m. ( <NUM> gmt ) , the <NUM>-year note us10yt = rr was up <NUM> / <NUM> for a yield of <NUM> percent from <NUM> percent wednesday . <SEP> 
| f_prd:  the the , , <NUM> , , , <NUM> <NUM> <NUM> <NUM> <NUM> <NUM> <NUM> <NUM> <NUM> , <NUM> . <NUM> <NUM> . <NUM> . . <NUM> . . . <SEP> 
| b_prd:  <GO> the <NUM> <GO> the <NUM> <NUM> <NUM> <GO> the , , <NUM> <NUM> , , $ <NUM> $ <NUM> of the <NUM> or <NUM> <NUM> $ <NUM> , <NUM> . <SEP> <SEP> <SEP> <SEP> <SEP> <SEP>


step:  1600 | time: 111.60 | loss: 4.712 
| tgt:  <GO> trading of echostar 's stock closed tuesday at a <NUM>-week high of $ <NUM> , up $ <NUM> . <SEP> 
| f_prd:  the , the , <NUM> <NUM> in in <NUM> <NUM> <NUM> to the <NUM> percent <NUM> <NUM> <NUM> . <SEP> 
| b_prd:  <GO> <GO> <GO> , the the , , of to in <NUM> to $ percent , to in <NUM> . <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP> <SEP>


step:  1680 | time: 111.3



step:  3120 | time: 111.00 | loss: 3.661 
| tgt:  <GO> with a wry smile , mr. bush replied , <quote> you 're looking pretty young these days . <quote> <SEP> 
| f_prd:  the the year <NUM> , the bush said to <quote> the 're 't <quote> <quote> , than . <SEP> <SEP> 
| b_prd:  <GO> <GO> a to said <GO> <GO> , said , <quote> we have and and of <NUM> child . .


step:  3200 | time: 111.85 | loss: 3.666 
| tgt:  <GO> neither iowa state athletic director bruce van de velde nor morgan could be reached for comment . <SEP> 
| f_prd:  the the company said of of <quote> , <quote> to the and be reached to the . <SEP> 
| b_prd:  <GO> <GO> the the , secretary the i john , that it to be school to <NUM> .


step:  3280 | time: 111.06 | loss: 3.583 
| tgt:  <GO> a kayaker found dr. scribner 's body floating near the doctor 's houseboat in portage bay , where he was eating lunch when his wife , ethel , left for an appointment . <SEP> 
| f_prd:  the technology-laced <NUM> the tom and chief of in the united



step:  4720 | time: 111.19 | loss: 2.636 
| tgt:  <GO> the shooting happened at <NUM> : <NUM> a.m. in the living room of the home the extended family shared on the city 's westside . <SEP> 
| f_prd:  the company <NUM> in <NUM> million <NUM> , in the new room of the second that extended of shared in the government 's westside . <SEP> 
| b_prd:  <GO> the shootings <GO> , <NUM> about <NUM> was of a <NUM> some of the of the extended and <quote> on the photojournalist the <NUM> .


step:  4800 | time: 110.95 | loss: 2.449 
| tgt:  <GO> after their arrests , sources said the men admitted they were smuggled into washington state from canada in july . <SEP> 
| f_prd:  the the arrests , the said the governor were to would no to the , , the to the <NUM> <SEP> 
| b_prd:  <GO> <GO> <GO> <NUM> <GO> , said the and where , be britain a the , in <NUM> in said .


step:  4880 | time: 111.84 | loss: 2.480 
| tgt:  <GO> a dod team is on site to determine how this happened and what needs to be done to f



step:  6240 | time: 111.25 | loss: 1.641 
| tgt:  <GO> in addition , juries in both state and federal cases have become increasingly reluctant to impose the death penalty . <SEP> 
| f_prd:  the the , the in the companies , have support in are further reluctant to impose the death . . <SEP> 
| b_prd:  <GO> in <GO> <GO> was of the , and other would have have that reluctant to impose the death said .


step:  6320 | time: 111.26 | loss: 1.844 
| tgt:  <GO> shares of san diego-based jack in the box closed at $ <NUM> , up <NUM> cents , or <NUM> percent , on the new york stock exchange . <SEP> 
| f_prd:  the of lendingtree diego-based rose at the <NUM> dropped at $ <NUM> , or <NUM> cents , or <NUM> percent , to the new york stock exchange . <SEP> 
| b_prd:  <GO> shares of san diego-based <GO> of the shares closed at $ <NUM> was or <NUM> cents , or <NUM> <NUM> <NUM> on the new york stock said .


step:  6400 | time: 111.10 | loss: 1.840 
| tgt:  <GO> the company has agreed terms on the purc



step:  7680 | time: 111.99 | loss: 1.466 
| tgt:  <GO> united airlines plans to become the first domestic airline to offer e-mail on all its domestic flights by the end of the year , the company announced yesterday . <SEP> 
| f_prd:  the airways have to be the first time to to charity subscribers on the of domestic flights by the end of the war . becoming company said . . <SEP> 
| b_prd:  <GO> united he tried to for the the the is to by e-mail for with its domestic flights by the end in last accountants , the it said said .


step:  7760 | time: 111.81 | loss: 1.318 
| tgt:  <GO> referring to a muslim fighter in somalia , boykin said that <quote> my god was bigger than his . <SEP> 
| f_prd:  the to the muslim fighter for somalia , boykin said , he it god was bigger than a family <SEP> 
| b_prd:  <GO> referring is a muslim fighter in somalia , company said , that my god was bigger than said .


step:  7840 | time: 112.17 | loss: 1.327 
| tgt:  <GO> adolescent specialist dr. michael co



step:  9120 | time: 112.23 | loss: 1.074 
| tgt:  <GO> the standard & poor 's retail index < .rlx > was up more than <NUM> percent . <SEP> 
| f_prd:  the technology-laced & poor 's <NUM> index < .rlx > was up <NUM> than <NUM> percent . <SEP> 
| b_prd:  <GO> the standard & poor 's <NUM> index < .rlx , giving rose more from <NUM> said .


step:  9200 | time: 112.11 | loss: 1.069 
| tgt:  <GO> the number of passengers aboard was not known and may never be , since ferry operators rarely keep full passenger lists . <SEP> 
| f_prd:  the company of the were was in known to undergraduate never be approved ferry <NUM> operators rarely be in . lists . <SEP> 
| b_prd:  <GO> the prisoner of <NUM> it was is known explorer will to <NUM> iraq and ferry have rarely and medical passenger said .


step:  9280 | time: 112.44 | loss: 1.032 
| tgt:  <GO> but close wondered whether the package would be worth the cost of licensing the third-party software , along with salesforce.com 's rental price . <SEP>

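當 BATCH_SIZE 由 16 改為 32、LEARNING_RATE 由 0.002 改為 0.001 時：每 80 步的耗時由約 79–97 秒增加為約 111–112 秒，因此固定 10000 步的總訓練時間較長（但每秒處理的樣本數約由 13–16 筆提高到約 23 筆）；在相近的訓練步數下 loss 較高（第二組在第 9120–9280 步約為 1.03–1.07，第一組在第 9520–9680 步約為 0.81–0.89）。學習速率減半可能是收斂較慢、loss 偏高的原因之一。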