# 作業 : 調整 Bert 模型的不同訓練參數

# [作業目標]
- 調整 Bert 模型的不同參數, 分別觀察 loss 數據並比較

# [作業重點]
- 調整 Bert 模型的不同訓練參數, 分別觀察 loss 數據並比較

# [參數說明]
- MODEL_DIM : Attention 特徵維度，即每個單字嵌入向量 (embedding) 的長度 (每筆輸入的單字個數則由資料的 max_len 決定)
- N_LAYER : Attention 堆疊的層數
- LEARNING_RATE : 學習速率，影響收斂的快慢
- MASK_RATE : 掩碼比例(介於 0 到 0.5 間, 建議值 0.15)

# 程式說明
- 程式採用 tensorflow2 / keras 寫作, 執行前請先安裝 tensorflow 2.0
- 本程式執行時, 請將 utils.py / transformer.py / GPT.py 等三個檔案與執行檔放置於同一目錄下
- 程式來源 : 莫煩Python-BERT:雙向語言模型 https://mofanpy.com/tutorials/machine-learning/nlp/bert/

In [1]:
# [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf)
import numpy as np
import tensorflow as tf
import utils    # this refers to utils.py in my [repo](https://github.com/MorvanZhou/NLP-Tutorials/)
import time
from GPT import GPT
import os
import pickle


class BERT(GPT):
    """Bidirectional encoder trained with masked-LM (MLM) and next-sentence (NSP) losses.

    Everything except the training step and the attention mask is inherited from
    ``GPT`` (GPT.py, not visible here): embeddings, encoder stack, output heads,
    ``self.opt`` and ``self.cross_entropy``.
    """

    def __init__(self, model_dim, max_len, n_layer, n_head, n_vocab, lr, max_seg=3, drop_rate=0.1, padding_idx=0):
        super().__init__(model_dim, max_len, n_layer, n_head, n_vocab, lr, max_seg, drop_rate, padding_idx)
        # I think task emb is not necessary for pretraining,
        # because the aim of all tasks is to train a universal sentence embedding
        # the body encoder is the same across all tasks,
        # and different output layer defines different task just like transfer learning.
        # finetuning replaces output layer and leaves the body encoder unchanged.

        # self.task_emb = keras.layers.Embedding(
        #     input_dim=n_task, output_dim=model_dim,  # [n_task, dim]
        #     embeddings_initializer=tf.initializers.RandomNormal(0., 0.01),
        # )

    def step(self, seqs, segs, seqs_, loss_mask, nsp_labels):
        """Run one optimisation step on a batch and return ``(loss, mlm_logits)``.

        Args:
            seqs: corrupted input token ids (masked / replaced / unchanged).
            segs: segment ids forwarded to ``self.call``.
            seqs_: uncorrupted token ids, used as MLM targets.
            loss_mask: boolean mask selecting which per-token MLM losses are
                averaged (assumed to match the shape of the per-token loss — TODO confirm).
            nsp_labels: next-sentence-prediction targets.

        Returns:
            The combined scalar loss ``mlm_loss + 0.2 * nsp_loss`` and the MLM logits.
        """
        with tf.GradientTape() as tape:
            mlm_logits, nsp_logits = self.call(seqs, segs, training=True)
            # Only the selected positions contribute to the MLM loss.
            mlm_loss_batch = tf.boolean_mask(self.cross_entropy(seqs_, mlm_logits), loss_mask)
            mlm_loss = tf.reduce_mean(mlm_loss_batch)
            nsp_loss = tf.reduce_mean(self.cross_entropy(nsp_labels, nsp_logits))
            # NSP is down-weighted relative to MLM.
            loss = mlm_loss + 0.2 * nsp_loss
        # Differentiate and update outside the tape context so these operations
        # are not themselves recorded on the tape.
        grads = tape.gradient(loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return loss, mlm_logits

    def mask(self, seqs):
        """Return a padding mask: 1.0 where ``seqs == padding_idx``, else 0.0.

        Unlike a causal mask, this only hides padding, so every token may attend
        to every non-padding token. How the mask is applied is defined in the
        attention code (transformer.py / GPT.py, not visible here).
        """
        mask = tf.cast(tf.math.equal(seqs, self.padding_idx), tf.float32)
        return mask[:, tf.newaxis, tf.newaxis, :]  # [n, 1, 1, step]


def _get_loss_mask(len_arange, seq, pad_id, mask_rate=None):
    """Randomly pick positions of ``seq`` that the masked-LM loss will be computed on.

    Args:
        len_arange: 1-D array of candidate positions to sample from.
        seq: 1-D token-id array of one sample; only its shape is used.
        pad_id: kept for backward compatibility; no longer affects the result.
        mask_rate: fraction of candidate positions to pick. When omitted, the
            module-level ``MASK_RATE`` is used if defined (set in the ``__main__``
            cell), otherwise 0.15.

    Returns:
        ``(loss_mask, rand_id)``. ``loss_mask`` is a boolean array of shape
        ``(1, len(seq))`` that is True exactly at the picked positions.
        ``rand_id`` holds those positions.
    """
    if mask_rate is None:
        # Avoid a NameError when this module is imported without running __main__.
        mask_rate = globals().get("MASK_RATE", 0.15)
    n_candidates = len(len_arange)
    # Pick at least 2 positions, but never more than there are candidates:
    # sampling without replacement raises if size exceeds the population.
    n_pick = min(n_candidates, max(2, int(mask_rate * n_candidates)))
    rand_id = np.random.choice(len_arange, size=n_pick, replace=False)
    # Start all-False. The previous np.full_like(seq, pad_id, dtype=np.bool) was
    # all-True whenever pad_id != 0, and np.bool no longer exists in NumPy >= 1.24.
    loss_mask = np.zeros_like(seq, dtype=bool)
    loss_mask[rand_id] = True
    return loss_mask[None, :], rand_id


def do_mask(seq, len_arange, pad_id, mask_id):
    """Overwrite randomly chosen positions of ``seq`` in place with ``mask_id``.

    Positions are drawn from ``len_arange`` by ``_get_loss_mask``. Returns that
    function's ``(1, len(seq))`` boolean loss mask marking the overwritten positions.
    """
    picked = _get_loss_mask(len_arange, seq, pad_id)
    positions = picked[1]
    seq[positions] = mask_id
    return picked[0]


def do_replace(seq, len_arange, pad_id, word_ids):
    """Overwrite randomly chosen positions of ``seq`` in place with random ids from ``word_ids``.

    Positions are drawn from ``len_arange`` by ``_get_loss_mask``, and replacement
    ids are sampled with replacement. Returns the boolean loss mask of the
    replaced positions.
    """
    loss_mask, positions = _get_loss_mask(len_arange, seq, pad_id)
    n_positions = len(positions)
    random_words = np.random.choice(word_ids, size=n_positions)
    seq[positions] = random_words
    return loss_mask


def do_nothing(seq, len_arange, pad_id):
    """Pick loss positions from ``len_arange`` but leave ``seq`` unchanged.

    The model must then reproduce the original tokens at those positions.
    Returns the boolean loss mask from ``_get_loss_mask``.
    """
    return _get_loss_mask(len_arange, seq, pad_id)[0]


def random_mask_or_replace(data, arange, batch_size):
    """Sample a batch and corrupt it for the masked-LM objective.

    One strategy is drawn for the whole batch (not per token):
      - p < 0.7         -> picked positions are overwritten with <MASK>
      - 0.7 <= p < 0.85 -> picked positions are left unchanged
      - otherwise       -> picked positions get random ids from ``data.word_ids``

    ``seqs`` is modified in place. ``seqs_`` is an untouched copy used as targets.

    Returns:
        ``(seqs, segs, seqs_, loss_mask, xlen, nsp_labels)``, where ``loss_mask``
        stacks one boolean row per sample.
    """
    seqs, segs, xlen, nsp_labels = data.sample(batch_size)
    seqs_ = seqs.copy()
    draw = np.random.random()

    def candidates(row):
        # Positions [0, xlen[row, 0]) and (xlen[row, 0], xlen[row].sum()];
        # index xlen[row, 0] itself is skipped.
        # NOTE(review): if the layout is "<GO> s1 <SEP> s2 ..." with xlen holding
        # the two sentence lengths, this range includes <GO> and the first <SEP>
        # and skips the last word of s1. Verify against utils.MRPCData.
        split = xlen[row, 0]
        return np.concatenate((arange[:split], arange[split + 1:xlen[row].sum() + 1]))

    rows = []
    for i, sample in enumerate(seqs):
        if draw < 0.7:
            rows.append(do_mask(sample, candidates(i), data.pad_id, data.v2i["<MASK>"]))
        elif draw < 0.85:
            rows.append(do_nothing(sample, candidates(i), data.pad_id))
        else:
            rows.append(do_replace(sample, candidates(i), data.pad_id, data.word_ids))
    loss_mask = np.concatenate(rows, axis=0)
    return seqs, segs, seqs_, loss_mask, xlen, nsp_labels


def train(model, data, step=10000, name="bert", batch_size=16, log_every=100):
    """Pre-train ``model`` on ``data`` for ``step`` steps, then save the weights.

    Each step draws a freshly corrupted batch with ``random_mask_or_replace`` and
    runs one ``model.step``. Every ``log_every`` steps a report on the first sample
    of the batch is printed:
      - the loss;
      - seconds since the previous report;
      - the sequence fed to the model (labelled "tgt"; it contains the masks/replacements);
      - the argmax prediction;
      - the target vs predicted words at the loss positions.

    Args:
        model: model exposing ``step(...)`` and ``save_weights(path)`` (e.g. BERT).
        data: dataset exposing ``max_len`` and ``i2v``, plus whatever
            ``random_mask_or_replace`` reads (e.g. ``utils.MRPCData``).
        step: number of training steps.
        name: run name; weights are written to ``./visual/models/<name>/model.ckpt``.
        batch_size: samples per step (previously hard-coded; default unchanged).
        log_every: report interval in steps (previously hard-coded; default unchanged).
    """
    t0 = time.time()
    arange = np.arange(0, data.max_len)
    for t in range(step):
        seqs, segs, seqs_, loss_mask, xlen, nsp_labels = random_mask_or_replace(data, arange, batch_size)
        loss, pred = model.step(seqs, segs, seqs_, loss_mask, nsp_labels)
        if t % log_every == 0:
            pred = pred[0].numpy().argmax(axis=1)
            t1 = time.time()
            shown = xlen[0].sum() + 1
            # Select loss positions with boolean indexing. Multiplying by the mask and
            # filtering out <PAD> ids dropped masked positions the model predicted as
            # <PAD> (misaligning the two lists), and was wrong if the PAD id is not 0.
            picked = loss_mask[0].astype(bool)
            print(
                "\n\nstep: ", t,
                "| time: %.2f" % (t1 - t0),
                "| loss: %.3f" % loss.numpy(),
                "\n| tgt: ", " ".join([data.i2v[i] for i in seqs[0][:shown]]),
                "\n| prd: ", " ".join([data.i2v[i] for i in pred[:shown]]),
                "\n| tgt word: ", [data.i2v[i] for i in seqs_[0][picked]],
                "\n| prd word: ", [data.i2v[i] for i in pred[picked]],
                )
            t0 = t1
    os.makedirs("./visual/models/%s" % name, exist_ok=True)
    model.save_weights("./visual/models/%s/model.ckpt" % name)


def export_attention(model, data, name="bert"):
    """Reload saved weights, run one forward pass, and pickle the attention maps.

    Loads ``./visual/models/<name>/model.ckpt`` and runs ``model.call`` on 32
    sampled examples. It then writes ``{"src": tokens per sample, "attentions":
    model.attentions}`` to ``./visual/tmp/<name>_attention_matrix.pkl`` for
    visualisation.
    """
    checkpoint = "./visual/models/%s/model.ckpt" % name
    model.load_weights(checkpoint)

    seqs, segs, _xlen, _nsp_labels = data.sample(32)
    # Third positional argument: presumably the `training` flag (cf. training=True in BERT.step).
    model.call(seqs, segs, False)

    tokens = [[data.i2v[tok] for tok in row] for row in seqs]
    payload = {"src": tokens, "attentions": model.attentions}

    out_path = "./visual/tmp/%s_attention_matrix.pkl" % name
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "wb") as fh:
        pickle.dump(payload, fh)


if __name__ == "__main__":
    # Hyper-parameters under study (see the parameter notes at the top of the notebook).
    MODEL_DIM = 256       # passed to BERT as model_dim
    N_LAYER = 4           # passed to BERT as n_layer
    LEARNING_RATE = 1e-4  # passed to BERT as lr
    MASK_RATE = 0.15      # read as a module-level global by _get_loss_mask

    utils.set_soft_gpu(True)

    d = utils.MRPCData("./MRPC", 2000)
    print("num word: ", d.num_word)

    bert_kwargs = dict(
        model_dim=MODEL_DIM,
        max_len=d.max_len,
        n_layer=N_LAYER,
        n_head=4,
        n_vocab=d.num_word,
        lr=LEARNING_RATE,
        max_seg=d.num_seg,
        drop_rate=0.2,
        padding_idx=d.v2i["<PAD>"],
    )
    m = BERT(**bert_kwargs)

    run_name = "bert"
    train(m, d, step=10000, name=run_name)
    export_attention(m, d, run_name)

num word:  12880


step:  0 | time: 1.75 | loss: 9.713 
| tgt:  <GO> around <NUM> gmt , tab shares were up <NUM> cents , or <NUM> % , at a $ <NUM> , having earlier set a record high of a $ <NUM> . <SEP> tab shares jumped <NUM> cents , or <NUM> % , to set a record closing high at a $ <NUM> 
| prd:  spaceflight son nominees recovered july hannum vigilante-style 149mph jordanian quattrone july costs intel smooth complied male desert nominees yes smooth marlins nominees shukrijumah hostages wholly quattrone vital allaire fsb realised motors spaceflight exiles temporarily eagan congratulations four we barnett causes causes hannum gained hannum nominees yorkers son desert sampson prohibition star-telegram decline eagan 
| tgt word:  ['around', 'gmt', 'were', 'high', 'tab', 'closing', '$'] 
| prd word:  ['son', 'recovered', '149mph', 'vital', 'temporarily', 'desert', 'decline']


step:  100 | time: 141.80 | loss: 7.807 
| tgt:  <GO> a panel of the 9th us circuit court of appeals upheld califo



step:  1400 | time: 124.82 | loss: 5.885 
| tgt:  <MASK> under the plan , maine would act as a <quote> pharmacy benefit manager <MASK> to lower the <MASK> of prescription drugs . <SEP> <MASK> <MASK> rx , <MASK> state lawmakers approved in <NUM> , maine would act as <MASK> <quote> pharmacy benefit manager <quote> to lower the cost of prescription drugs 
| prd:  <GO> the the , , , , the , a <quote> , <NUM> the the to <NUM> the the of the the <SEP> <SEP> the the the , the of the , in <NUM> , to to the , the <quote> the <NUM> , <quote> to , the to of , , 
| tgt word:  ['<GO>', '<quote>', 'cost', 'under', 'maine', 'which', 'a'] 
| prd word:  ['<GO>', 'the', 'the', 'the', 'the', 'the', 'the']


step:  1500 | time: 125.93 | loss: 6.357 
| tgt:  <GO> china has threatened to execute or jail for life anyone who breaks their quarantine and intentionally spreads the killer sars virus . <MASK> china , haunted <MASK> the spread of sars in its <MASK> countryside , has threatened to <MASK> or jail f



step:  2800 | time: 144.29 | loss: 6.245 
| tgt:  <GO> the state wants to lower wolf <MASK> in approximately a <MASK> area near the village <MASK> mcgrath . <SEP> <MASK> state wants to kill the wolves in approximately a <MASK> area near mcgrath to <MASK> a moose nursery of sorts 
| prd:  <GO> the 's to to to to to in to a to to a the to to to <SEP> <SEP> to to to to to the to in to a to to a a to a a to to of to 
| tgt word:  ['numbers', '1,700-square-mile', 'of', 'the', '1,700-square-mile', 'establish'] 
| prd word:  ['to', 'to', 'to', 'to', 'to', 'a']


step:  2900 | time: 144.98 | loss: 6.368 
| tgt:  <GO> georgia cannot joe to not get funding , <quote> said dr. melinda rowe , chatham county 's health director . <SEP> critically cannot afford to not get heating , <quote> said county dead brothers dr. melinda rowe 
| prd:  <GO> <quote> <quote> , to not to <quote> , <quote> said <quote> <quote> <quote> , , <quote> 's . <SEP> . <SEP> <quote> <quote> <quote> to not , , , <quote> said 



step:  4000 | time: 132.69 | loss: 5.793 
| tgt:  <MASK> <quote> dan brings <MASK> coca-cola <MASK> experience in managing some of <MASK> world 's largest and most familiar brands , <quote> <MASK> said in a statement . <SEP> in a statement , heyer <MASK> , <quote> <MASK> brings to <MASK> enormous experience in managing some of the world 's largest and most familiar brands 
| prd:  <GO> <quote> <quote> <quote> <quote> <quote> the <quote> in <quote> the of 's <quote> 's and and <quote> <quote> <quote> , <quote> <quote> said in a <SEP> <SEP> <SEP> in a <quote> , <quote> <quote> , <quote> <quote> <quote> to <quote> <quote> <quote> in <quote> the of the <quote> 's 's and <quote> <quote> <quote> 
| tgt word:  ['<GO>', 'to', 'enormous', 'the', 'heyer', 'said', 'dan', 'coca-cola'] 
| prd word:  ['<GO>', '<quote>', 'the', "'s", '<quote>', '<quote>', '<quote>', '<quote>']


step:  4100 | time: 125.84 | loss: 5.535 
| tgt:  <GO> <MASK> suits expand exponentially the number of plaintiffs and dam



step:  5200 | time: 127.97 | loss: 5.097 
| tgt:  <GO> at his <MASK> , he will <MASK> reassigned within the district . <SEP> district superintendent <MASK> chester <MASK> told reporters monday that mccrackin will be reassigned within the district 
| prd:  <GO> at his will , he will district district district the district . <SEP> district district district he district told be monday that district will be district district the district 
| tgt word:  ['request', 'be', 'j.', 'floyd'] 
| prd word:  ['will', 'district', 'district', 'district']


step:  5300 | time: 125.13 | loss: 5.584 
| tgt:  <GO> allegiant lamo , <NUM> , had told reporters dale planned to cavalier to the fbi in sacramento friday mutiny but he cause had second thoughts . <SEP> lamo had told reporters he would surrender to the fbi on the federal courthouse steps chi-chi sacramento pleased friday , but he didn 't show up 
| prd:  <GO> he was , <NUM> , had to , the would to he to the would in on on the but he on had <SEP> h



step:  6600 | time: 128.88 | loss: 3.061 
| tgt:  <GO> in nairobi , kenya , the very rev. peter karanja , provost of all saints cathedral , said the u.s. episcopal church <quote> is alienating itself from the anglican communion . <quote> <SEP> the episcopal church ' ' is alienating itself from the anglican communion , ' ' said the very rev. peter karanja , provost of the all saints cathedral , in nairobi 
| prd:  <GO> in ' , , , the very ' ' ' , ' of all <quote> <quote> , said the u.s. ' ' <quote> is communion communion from the communion communion . <quote> <SEP> the communion church ' ' is communion the from the communion communion , ' ' said the of , ' communion , said of the all ' communion , in communion 
| tgt word:  ['rev.', 'saints', 'episcopal', 'from', 'anglican', 'is', ',', "'", 'peter'] 
| prd word:  ["'", '<quote>', "'", 'from', 'communion', 'is', ',', "'", "'"]


step:  6700 | time: 134.99 | loss: 2.424 
| tgt:  <GO> the us federal trade commission has also filed a laws



step:  7900 | time: 126.31 | loss: 3.743 
| tgt:  <GO> this morning , at <MASK> 's new york office , coen revised his expectations downward , saying that <MASK> <MASK> instead rise <NUM> percent to $ <NUM> billion . <MASK> speaking to <MASK> <MASK> a <MASK> york news conference , universal mccann 's coen <MASK> that total u.s. ad spending will rise <NUM> <MASK> to $ <NUM> billion this year 
| prd:  <GO> this stock , at will 's new york of , new $ his billion billion , will that will will $ $ <NUM> percent to $ <NUM> billion . <SEP> to to new at a percent york news 's , $ billion 's billion billion that billion u.s. percent billion will billion <NUM> , to $ <NUM> billion this year 
| tgt word:  ['um', 'spending', 'would', '<SEP>', 'reporters', 'at', 'new', 'projected', 'percent'] 
| prd word:  ['will', 'will', 'will', '<SEP>', 'new', 'at', 'percent', 'billion', ',']


step:  8000 | time: 129.83 | loss: 3.708 
| tgt:  <GO> the dow jones industrial average .dji jumped <NUM> percent , wh



step:  9200 | time: 126.74 | loss: 3.455 
| tgt:  <MASK> the street-racing <MASK> <quote> <NUM> fast <NUM> furious <quote> won <MASK> pole position at the box <MASK> , taking in an estimated $ <NUM> million in its <MASK> weekend . <SEP> the pg-13 sequel <quote> <NUM> <MASK> <MASK> furious <quote> raked in <MASK> estimated $ <NUM> million during its opening weekend , jumping over last <MASK> 's catch , <quote> finding nemo . 
| prd:  <GO> the dollar fast <quote> <NUM> fast <NUM> fast <quote> i the weekend weekend at the weekend , , taking in an estimated $ <NUM> million in its weekend weekend . <SEP> the pg-13 weekend <quote> <NUM> fast <NUM> weekend <quote> fast in the estimated $ <NUM> million during its weekend weekend , weekend over last weekend 's weekend , <quote> <quote> fast . 
| tgt word:  ['<GO>', 'sequel', 'the', 'office', 'opening', 'fast', '<NUM>', 'an', 'weekend'] 
| prd word:  ['<GO>', 'fast', 'the', ',', 'weekend', 'fast', '<NUM>', 'the', 'weekend']


step:  9300 | tim