In [1]:
pip install sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.97


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import sentencepiece as spm
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from random import random, randrange, randint, shuffle, choice
from typing import Optional
import torch.optim as optim
import torch.nn.functional as F
import os
from tqdm import tqdm, tqdm_notebook, trange
import json
import numpy as np

In [6]:
corpus = "/content/drive/MyDrive/WikiQA.txt"
prefix = "WikiQA"
vocab_size = 12000
spm.SentencePieceTrainer.train(
    f"--input={corpus} --model_prefix={prefix} --vocab_size={vocab_size + 7}" + 
    " --model_type=bpe" +
    " --max_sentence_length=999999" + # 문장 최대 길이
    " --pad_id=0 --pad_piece=[PAD]" + # pad (0)
    " --unk_id=1 --unk_piece=[UNK]" + # unknown (1)
    " --bos_id=2 --bos_piece=[BOS]" + # begin of sequence (2)
    " --eos_id=3 --eos_piece=[EOS]" + # end of sequence (3)
    " --user_defined_symbols=[SEP],[CLS],[MASK]") # 사용자 정의 토큰

In [7]:
vocab_file = "WikiQA.model"
vocab = spm.SentencePieceProcessor()
vocab.load(vocab_file)

lines = [
  'i hydro you , i'
]
for line in lines:
    pieces = vocab.encode_as_pieces(line)
    ids = vocab.encode_as_ids(line)
    a = vocab.Decode(ids)
    print(line)
    print(pieces)
    print(ids)
    print(a)
    print()
vocab.decode(1737)

i hydro you , i
['▁i', '▁hydro', '▁you', '▁,', '▁i']
[537, 1866, 1523, 32, 537]
i hydro you , i



'white'

In [8]:
  class Config(dict): 
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

    @classmethod
    def load(cls, file):
        with open(file, 'r') as f:
            config = json.loads(f.read())
            return Config(config)

In [9]:
config = Config({
    "n_enc_vocab": len(vocab),
    "n_enc_seq": 256,
    "n_seg_type": 2,
    "n_layer": 12,
    "d_hidn": 768,
    "i_pad": 0,
    "d_ff": 1024,
    "n_head": 16,
    "d_head": 48,
    "dropout": 0.1,
    "layer_norm_epsilon": 1e-12
})
print(config)

{'n_enc_vocab': 12007, 'n_enc_seq': 256, 'n_seg_type': 2, 'n_layer': 12, 'd_hidn': 768, 'i_pad': 0, 'd_ff': 1024, 'n_head': 16, 'd_head': 48, 'dropout': 0.1, 'layer_norm_epsilon': 1e-12}


In [10]:
""" attention pad mask """
def get_attn_pad_mask(seq_q, seq_k, i_pad):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(i_pad).unsqueeze(1).expand(batch_size, len_q, len_k)  # 
    return pad_attn_mask

""" scale dot product attention """
class ScaledDotProductAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(config.dropout)
        self.scale = 1 / (self.config.d_head ** 0.5)
    
    def forward(self, Q, K, V, attn_mask):
        # (bs, n_head, n_q_seq, n_k_seq)
        scores = torch.matmul(Q, K.transpose(-1, -2)).mul_(self.scale)
        scores.masked_fill_(attn_mask, -1e9)
        # (bs, n_head, n_q_seq, n_k_seq)
        attn_prob = nn.Softmax(dim=-1)(scores)
        attn_prob = self.dropout(attn_prob)
        # (bs, n_head, n_q_seq, d_v)
        context = torch.matmul(attn_prob, V)
        # (bs, n_head, n_q_seq, d_v), (bs, n_head, n_q_seq, n_v_seq)
        return context, attn_prob


""" multi head attention """
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.W_Q = nn.Linear(self.config.d_hidn, self.config.n_head * self.config.d_head)
        self.W_K = nn.Linear(self.config.d_hidn, self.config.n_head * self.config.d_head)
        self.W_V = nn.Linear(self.config.d_hidn, self.config.n_head * self.config.d_head)
        self.scaled_dot_attn = ScaledDotProductAttention(self.config)
        self.linear = nn.Linear(self.config.n_head * self.config.d_head, self.config.d_hidn)
        self.dropout = nn.Dropout(config.dropout)
    
    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.size(0)
        # (bs, n_head, n_q_seq, d_head)
        q_s = self.W_Q(Q).view(batch_size, -1, self.config.n_head, self.config.d_head).transpose(1,2)
        # (bs, n_head, n_k_seq, d_head)
        k_s = self.W_K(K).view(batch_size, -1, self.config.n_head, self.config.d_head).transpose(1,2)
        # (bs, n_head, n_v_seq, d_head)
        v_s = self.W_V(V).view(batch_size, -1, self.config.n_head, self.config.d_head).transpose(1,2)

        # (bs, n_head, n_q_seq, n_k_seq)
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.config.n_head, 1, 1)

        # (bs, n_head, n_q_seq, d_head), (bs, n_head, n_q_seq, n_k_seq)
        context, attn_prob = self.scaled_dot_attn(q_s, k_s, v_s, attn_mask)
        # (bs, n_head, n_q_seq, h_head * d_head)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.n_head * self.config.d_head)
        # (bs, n_head, n_q_seq, e_embd)
        output = self.linear(context)
        output = self.dropout(output)
        # (bs, n_q_seq, d_hidn), (bs, n_head, n_q_seq, n_k_seq)
        return output, attn_prob


""" feed forward """
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.conv1 = nn.Conv1d(in_channels=self.config.d_hidn, out_channels=self.config.d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=self.config.d_ff, out_channels=self.config.d_hidn, kernel_size=1)
        # self.w_1 = nn.Linear(self.config.d_hidn, self.config.d_ff)
        # self.w_2 = nn.Linear(self.config.d_ff, self.config.d_hidn)
        self.active = F.gelu
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, inputs):
        # (bs, d_ff, n_seq)
        output = self.active(self.conv1(inputs.transpose(1, 2)))
        # (bs, n_seq, d_hidn)
        output = self.conv2(output).transpose(1, 2)
        output = self.dropout(output)
        # (bs, n_seq, d_hidn)
        return output
        # return self.w_2(self.dropout(self.active(self.w_1(inputs))))

In [11]:
""" encoder layer """
class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.self_attn = MultiHeadAttention(self.config)
        self.layer_norm1 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)
        self.pos_ffn = PoswiseFeedForwardNet(self.config)
        self.layer_norm2 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)
    
    def forward(self, inputs, attn_mask):
        # (bs, n_enc_seq, d_hidn), (bs, n_head, n_enc_seq, n_enc_seq)
        att_outputs, attn_prob = self.self_attn(inputs, inputs, inputs, attn_mask)
        att_outputs = self.layer_norm1(inputs + att_outputs)
        # (bs, n_enc_seq, d_hidn)
        ffn_outputs = self.pos_ffn(att_outputs)
        ffn_outputs = self.layer_norm2(ffn_outputs + att_outputs)
        # (bs, n_enc_seq, d_hidn), (bs, n_head, n_enc_seq, n_enc_seq)
        return ffn_outputs, attn_prob
     

In [12]:
""" encoder """
class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.enc_emb = nn.Embedding(self.config.n_enc_vocab, self.config.d_hidn)
        self.pos_emb = nn.Embedding(self.config.n_enc_seq + 1, self.config.d_hidn)
        self.seg_emb = nn.Embedding(self.config.n_seg_type, self.config.d_hidn)

        self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config.n_layer)])
    
    def forward(self, inputs, segments):
        positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype).expand(inputs.size(0), inputs.size(1)).contiguous() + 1
        pos_mask = inputs.eq(self.config.i_pad)
        positions.masked_fill_(pos_mask, 0)

        # (bs, n_enc_seq, d_hidn)
        outputs = self.enc_emb(inputs) + self.pos_emb(positions)  + self.seg_emb(segments)

        # (bs, n_enc_seq, n_enc_seq)
        attn_mask = get_attn_pad_mask(inputs, inputs, self.config.i_pad)

        attn_probs = []
        for layer in self.layers:
            # (bs, n_enc_seq, d_hidn), (bs, n_head, n_enc_seq, n_enc_seq)
            outputs, attn_prob = layer(outputs, attn_mask)
            attn_probs.append(attn_prob)
        # (bs, n_enc_seq, d_hidn), [(bs, n_head, n_enc_seq, n_enc_seq)]
        return outputs, attn_probs
     

In [13]:
class BERT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.encoder = Encoder(self.config)

        self.linear = nn.Linear(config.d_hidn, config.d_hidn)
        self.activation = torch.tanh
    
    def forward(self, inputs, segments):
        # (bs, n_seq, d_hidn), [(bs, n_head, n_enc_seq, n_enc_seq)]
        outputs, self_attn_probs = self.encoder(inputs, segments)
        # (bs, d_hidn)
        outputs_cls = outputs[:, 0].contiguous()
        outputs_cls = self.linear(outputs_cls)
        outputs_cls = self.activation(outputs_cls)
        # (bs, n_enc_seq, n_enc_vocab), (bs, d_hidn), [(bs, n_head, n_enc_seq, n_enc_seq)]
        return outputs, outputs_cls, self_attn_probs
    
    def save(self, epoch, loss, path):
        torch.save({
            "epoch": epoch,
            "loss": loss,
            "state_dict": self.state_dict()
        }, path)
    
    def load(self, path):
        save = torch.load(path)
        self.load_state_dict(save["state_dict"])
        return save["epoch"], save["loss"]

In [14]:
""" BERT pretrain """
class BERTPretrain(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.bert = BERT(self.config)
        # classfier
        self.projection_cls = nn.Linear(self.config.d_hidn, 2, bias=False)
        # lm
        self.projection_lm = nn.Linear(self.config.d_hidn, self.config.n_enc_vocab, bias=False)
        self.projection_lm.weight = self.bert.encoder.enc_emb.weight
        self.softmax = nn.LogSoftmax(dim=-1)
    
    def forward(self, inputs, segments):
        # (bs, n_enc_seq, d_hidn), (bs, d_hidn), [(bs, n_head, n_enc_seq, n_enc_seq)]
        outputs, outputs_cls, attn_probs = self.bert(inputs, segments)
        # (bs, 2)
        logits_cls = self.projection_cls(outputs_cls)
        # logits_cls = self.softmax(logits_cls)
        # (bs, n_enc_seq, n_enc_vocab)
        logits_lm = self.projection_lm(outputs)
        # logits_lm = self.softmax(logits_lm)
        # (bs, n_enc_vocab), (bs, n_enc_seq, n_enc_vocab), [(bs, n_head, n_enc_seq, n_enc_seq)]
        return logits_cls, logits_lm, attn_probs

In [15]:
""" 마스크 생성 """
def create_pretrain_mask(tokens, mask_cnt, vocab_list):
    cand_idx = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        if 0 < len(cand_idx) and not token.startswith(u"\u2581"):
            cand_idx[-1].append(i)
        else:
            cand_idx.append([i])
    shuffle(cand_idx)

    mask_lms = []
    for index_set in cand_idx:
        if len(mask_lms) >= mask_cnt:
            break
        if len(mask_lms) + len(index_set) > mask_cnt:
            continue
        for index in index_set:
            masked_token = None
            if random() < 0.8: # 80% replace with [MASK]
                masked_token = "[MASK]"
            else:
                if random() < 0.5: # 10% keep original
                    masked_token = tokens[index]
                else: # 10% random word
                    masked_token = choice(vocab_list)
            mask_lms.append({"index": index, "label": tokens[index]})
            tokens[index] = masked_token
    mask_lms = sorted(mask_lms, key=lambda x: x["index"])
    mask_idx = [p["index"] for p in mask_lms]
    mask_label = [p["label"] for p in mask_lms]

    return tokens, mask_idx, mask_label

In [16]:
def trim_tokens(tokens_a, tokens_b, max_seq):
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_seq:
            break

        if len(tokens_a) > len(tokens_b):
            del tokens_a[0]
        else:
            tokens_b.pop()

In [17]:
""" doc별 pretrain 데이터 생성 """
def create_pretrain_instances(docs, doc_idx, doc, n_seq, mask_prob, vocab_list):
    # for CLS], [SEP], [SEP]
    max_seq = n_seq - 3
    tgt_seq = max_seq
    
    instances = []
    current_chunk = []
    current_length = 0
    a_end=1
    for i in range(len(doc)):
        current_chunk.append(doc[i]) # line
        current_length += len(doc[i])
    for i in range(len(current_chunk)):
        if current_chunk[i] =='▁,':
            a_end=i
    tokens_a = doc[:a_end]
    tokens_b = doc[a_end+1: ]
   
    is_next = 1
    
    if random() < 0.5:
        is_next = 0
        tokens_b = []
        tokens_b_len = tgt_seq - len(tokens_a)
        random_doc_idx = 0
        while True:
            random_doc_idx = randrange(0, len(docs))
            if random_doc_idx != doc_idx:
              break
        random_doc = docs[random_doc_idx]

        random_start = randrange(0, len(random_doc))
        for j in range(random_start, len(random_doc)):
            tokens_b.append(random_doc[j])
    
    
    
    trim_tokens(tokens_a, tokens_b, max_seq)
    if len(tokens_b) > 0 and len(tokens_a) > 0:
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
        segment = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

        tokens, mask_idx, mask_label = create_pretrain_mask(tokens, int((len(tokens) - 3) * mask_prob), vocab_list)

        instance = {
            "tokens": tokens,
            "segment": segment,
            "is_next": is_next,
            "mask_idx": mask_idx,
            "mask_label": mask_label
        }
        instances.append(instance)
        current_chunk = []
        current_length = 0

        # if i == len(doc) - 1 or current_length >= tgt_seq:
        #     if 0 < len(current_chunk):                    
        #         a_end = 1
        #         if 1 < len(current_chunk):
        #             a_end = randrange(1, len(current_chunk))
        #         tokens_a = []
        #         for j in range(a_end):
        #             tokens_a.append(current_chunk[j])
                
        #         tokens_b = []
        #         if len(current_chunk) == 1 or random() < 0.5:
        #             is_next = 0
        #             tokens_b_len = tgt_seq - len(tokens_a)
        #             random_doc_idx = doc_idx
        #             while doc_idx == random_doc_idx:
        #                 random_doc_idx = randrange(0, len(docs))
        #             random_doc = docs[random_doc_idx]

        #             random_start = randrange(0, len(random_doc))
        #             for j in range(random_start, len(random_doc)):
        #                 tokens_b.append(random_doc[j])
        #         else:
        #             is_next = 1
        #             for j in range(a_end, len(current_chunk)):
        #                 tokens_b.append(current_chunk[j])

        #         trim_tokens(tokens_a, tokens_b, max_seq)
        #         assert 0 < len(tokens_a)
        #         assert 0 < len(tokens_b)

        #         tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
        #         segment = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

        #         tokens, mask_idx, mask_label = create_pretrain_mask(tokens, int((len(tokens) - 3) * mask_prob), vocab_list)

        #         instance = {
        #             "tokens": tokens,
        #             "segment": segment,
        #             "is_next": is_next,
        #             "mask_idx": mask_idx,
        #             "mask_label": mask_label
        #         }
        #         instances.append(instance)
        #     current_chunk = []
        #     current_length = 0
    return instances  

In [18]:
def make_pretrain_data(vocab, in_file, out_file, count, n_seq, mask_prob):
    vocab_list = []
    for id in range(vocab.get_piece_size()):
        if not vocab.is_unknown(id):
            vocab_list.append(vocab.id_to_piece(id))

    line_cnt = 0
    with open(in_file, "r", encoding='UTF8') as in_f:
        for line in in_f:
            line_cnt += 1
    docs = []
    with open(in_file, "r", encoding='UTF8') as f:
        doc = []
        with tqdm(total=line_cnt, desc=f"Loading") as pbar:
            for i, line in enumerate(f):
                line = line.strip()
                if line == "":
                    if 0 < len(doc):
                        docs.append(doc)
                        doc = []
                        # 메모리 사용량을 줄이기 위해 100,000개만 처리 함
                        if 100000 < len(docs): break
                else:
                    pieces = vocab.encode_as_pieces(line)
                    if 0 < len(pieces):
                        doc.append(pieces)
                pbar.update(1)
        if doc:
            docs.append(doc)
    docs = sum(docs,[])
    for index in range(count):
        output = out_file.format(index)
        # if os.path.isfile(output): continue
        
        with open(output, "w", encoding='UTF8') as out_f:
            with tqdm(total=len(docs), desc=f"Making") as pbar:
                for i, doc in enumerate(docs):
                    instances = create_pretrain_instances(docs, i, doc, n_seq, mask_prob, vocab_list)
                    for instance in instances:
                        print(instance, file=out_f)
                        # out_f.write(instance)
                        # out_f.write("\n")
                    pbar.update(1)

In [19]:
in_file = "/content/drive/MyDrive/WikiQA.txt"
out_file = "WikiQA_{}.json"
count = 1
n_seq = 256
mask_prob = 0.15

make_pretrain_data(vocab, in_file, out_file, count, n_seq, mask_prob)

Loading: 100%|██████████| 63282/63282 [00:06<00:00, 9179.91it/s]
Making: 100%|██████████| 63282/63282 [00:03<00:00, 18813.62it/s]


In [20]:
class PretrainDataSet(torch.utils.data.Dataset):
    def __init__(self, vocab, infile):
        self.vocab = vocab
        self.labels_cls = []
        self.labels_lm = []
        self.sentences = []
        self.segments = []
    
        line_cnt = 0
        with open(infile, "r", encoding="UTF8") as f:
            for line in f:
                line_cnt += 1

        with open(infile, "r",encoding="UTF8") as f:
            for i, line in enumerate(tqdm(f, total=line_cnt, desc=f"Loading {infile}", unit=" lines")):
                line = json.dumps(line)
                instance = eval(json.loads(line))
                self.labels_cls.append(instance['is_next'])
                sentences = [vocab.piece_to_id(p) for p in instance["tokens"]]
                self.sentences.append(sentences)
                self.segments.append(instance["segment"])
                mask_idx = np.array(instance["mask_idx"], dtype=np.int64)
                mask_label = np.array([vocab.piece_to_id(p) for p in instance["mask_label"]], dtype=np.int64)
                label_lm = np.full(len(sentences), dtype=np.int64, fill_value=-1)
                label_lm[mask_idx] = mask_label
                self.labels_lm.append(label_lm)
    
    def __len__(self):
        assert len(self.labels_cls) == len(self.labels_lm)
        assert len(self.labels_cls) == len(self.sentences)
        assert len(self.labels_cls) == len(self.segments)
        return len(self.labels_cls)
    
    def __getitem__(self, item):
        return (torch.tensor(self.labels_cls[item]),
                torch.tensor(self.labels_lm[item]),
                torch.tensor(self.sentences[item]),
                torch.tensor(self.segments[item]))

In [21]:
""" pretrain data collate_fn """
def pretrin_collate_fn(inputs):
    labels_cls, labels_lm, inputs, segments = list(zip(*inputs))

    labels_lm = torch.nn.utils.rnn.pad_sequence(labels_lm, batch_first=True, padding_value=0)
    inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=0)
    segments = torch.nn.utils.rnn.pad_sequence(segments, batch_first=True, padding_value=0)
   
    batch = [
        torch.stack(labels_cls, dim=0),
        labels_lm,
        inputs,
        segments
    ]
    return batch

In [22]:
""" pretrain 데이터 로더 """
batch_size = 64
dataset = PretrainDataSet(vocab, "WikiQA_0.json")
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=pretrin_collate_fn)

Loading WikiQA_0.json: 100%|██████████| 63281/63281 [00:12<00:00, 5243.39 lines/s]


In [23]:
# def train_epoch_nsp(config, epoch, model, criterion_cls, optimizer, train_loader):
#     losses = []
#     model.train()
#     with tqdm(total=len(train_loader), desc=f"Train({epoch})") as pbar:
#         for i, value in enumerate(train_loader):
#             labels_cls, labels_lm, inputs, segments = map(lambda v: v.to(config.device), value)
#             optimizer.zero_grad()
#             outputs = model(inputs, segments)
#             logits_cls = outputs[0]
        
#             loss_cls = criterion_cls(logits_cls, labels_cls)
            
            

#             losses.append(loss_cls)
    
#             loss_cls.backward()
#             optimizer.step()

#             pbar.update(1)
#             pbar.set_postfix_str(f"Loss: {loss_cls:.3f} ({np.mean(losses):.3f})")
#     return np.mean(losses)

In [24]:
def train_epoch_mlm(config, epoch, model, criterion_lm, optimizer, train_loader):
    losses = []
    model.train()
    with tqdm(total=len(train_loader), desc=f"Train({epoch})") as pbar:
        for i, value in enumerate(train_loader):
            labels_cls, labels_lm, inputs, segments = map(lambda v: v.to(config.device), value)
            optimizer.zero_grad()
            outputs = model(inputs, segments)
            logits_lm =  outputs[1]
        
            loss_lm = criterion_lm(logits_lm.view(-1, logits_lm.size(2)), labels_lm.view(-1))

            loss_val = loss_lm.item()
            losses.append(loss_val)
    
            loss_lm.backward()
            optimizer.step()

            pbar.update(1)
            pbar.set_postfix_str(f"Loss: {loss_val:.3f} ({np.mean(losses):.3f})")
    return np.mean(losses)

In [33]:
config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(config)

learning_rate = 1e-4
betas=(0.9, 0.999)
weight_decay = 0.01
n_epoch = 50

{'n_enc_vocab': 12007, 'n_enc_seq': 256, 'n_seg_type': 2, 'n_layer': 12, 'd_hidn': 768, 'i_pad': 0, 'd_ff': 1024, 'n_head': 16, 'd_head': 48, 'dropout': 0.1, 'layer_norm_epsilon': 1e-12, 'device': device(type='cuda')}


In [34]:
model = BERTPretrain(config)

save_pretrain = "save_bert_pretrain.pth"
best_epoch, best_loss = 0, 0
if os.path.isfile(save_pretrain):
    best_epoch, best_loss = model.bert.load(save_pretrain)
    print(f"load pretrain from: {save_pretrain}, epoch={best_epoch}, loss={best_loss}")
    best_epoch += 1

model.to(config.device)

criterion_lm = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean')
criterion_cls = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# losses = []
# offset = best_epoch
# for step in range(n_epoch):
#     epoch = step + offset
#     loss= train_epoch_nsp(config, epoch, model,criterion_cls optimizer, train_loader)
#     losses.append(loss)
#     model.bert.save(epoch, loss, save_pretrain)

losses = []
offset = best_epoch
for step in range(n_epoch):
    epoch = step + offset
    loss= train_epoch_mlm(config, epoch, model, criterion_lm, optimizer, train_loader)
    losses.append(loss)
    model.bert.save(epoch, loss, save_pretrain)

load pretrain from: save_bert_pretrain.pth, epoch=49, loss=0.6201690233061118


Train(50): 100%|██████████| 989/989 [02:21<00:00,  6.97it/s, Loss: 0.410 (0.621)]


#### Mask 예측 Test

In [35]:
test_inputs = vocab.encode('how long is the term for federal[MASK] , In the United States  the title of federal judge usually means a judge appointed by the President of the United States and confirmed by the United States Senate pursuant to the Appointments Clause in Article II of the United States Constitution')
test_inputs = torch.tensor([test_inputs]).to(config.device)
test_segments = torch.zeros(len(test_inputs), dtype=int).to(config.device)
model.eval()
out = model(test_inputs, test_segments)
[vocab.decode(i) for i in out[1].argmax(2).cpu().detach().tolist()]

['is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is']

In [36]:
vocab.encode('[MASK]')
vocab.decode(0)

''

In [37]:
test_segments = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0,0,0,0,0, 1, 1, 1, 1, 1, 1,1,1]
test_segments = torch.tensor(test_segments).to(config.device)

In [38]:
test_inputs.shape, test_segments.shape

(torch.Size([1, 55]), torch.Size([30]))

In [57]:
model.eval()
out = model(test_inputs, test_segments)
[vocab.decode(i) for i in out[1].argmax(2).cpu().detach().tolist()]

['is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is',
 'is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is',
 'is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is',
 'is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is',
 'is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is',
 'is is is is is is is is is is is is is is is is is is is is is 

In [40]:
test_sentence = '[CLS] what are your hobbies [SEP] i have been watching and playing [MASK] more than ten years , so that both are my hobbies'
mask_idx1 = 13
sentence1 = 'what are your hobbies'
sentence2 = 'i have been watching and playing'
sentence3 = 'more than ten years , so that both are my hobbies'

In [41]:
test_input1 = [5]
for i in sentence1.split(' '):
    test_input1 += vocab.Encode(i)
test_input1 += [4]
for i in sentence2.split(' '):
    test_input1 += vocab.Encode(i)
test_input1 += [6]
for i in sentence3.split(' '):
    test_input1 += vocab.Encode(i)
test_input1 += [4]
test_input1 += ([0] * (61 - len(test_input1)))
test_input1 = torch.tensor(test_input1)

In [42]:
test_input1

tensor([    5,    51,    97,  3518,    46,   394, 11967,   121,     4,   537,
          258,   366,  2966,  6386,    42,  4129,     6,   445,   507,  2109,
          627,    32,   707,   160,   787,    97,  1520,    46,   394, 11967,
          121,     4,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0])

In [43]:
test_segment1 = torch.tensor([0]*6 + [1]*18 + [0]*37)

In [44]:
test_segment1

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [45]:
test_sentence = '[CLS] what is your greatest strength [SEP] i have been working hard in the biology field for about thirty years , that is why i can [MASK] saying that diligence is my greatest strength'
mask_idx2 = 26
sentence1 = 'what is your greatest strength'
sentence2 = 'i have been working hard in the biology field for about thirty years , that is why i can'
sentence3 = 'saying that diligence is my greatest strength'

In [46]:
test_input2 = [5]
for i in sentence1.split(' '):
    test_input2 += vocab.Encode(i)
test_input2 += [4]
for i in sentence2.split(' '):
    test_input2 += vocab.Encode(i)
test_input2 += [6]
for i in sentence3.split(' '):
    test_input2 += vocab.Encode(i)
test_input2 += ([0] * (61 - len(test_input2)))
test_input2 = torch.tensor(test_input2)

In [47]:
test_input2

tensor([    5,    51,    41,  3518,  2588,  8758,     4,   537,   258,   366,
         3566,  2512,    28,    11,  8320,  1646,    83,   597,  8243,   627,
           32,   160,    41,  8070,   537,   307,     6, 11778,   160,    38,
           55,  3505,    41,  1520,  2588,  8758,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0])

In [48]:
test_segment2 = torch.tensor([0]*7 + [1]*27 + [0]*27)

In [49]:
test_segment2

tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [50]:
test_inputs = torch.stack([test_input1,test_input2]*16,0)

In [51]:
test_inputs

tensor([[ 5, 51, 97,  ...,  0,  0,  0],
        [ 5, 51, 41,  ...,  0,  0,  0],
        [ 5, 51, 97,  ...,  0,  0,  0],
        ...,
        [ 5, 51, 41,  ...,  0,  0,  0],
        [ 5, 51, 97,  ...,  0,  0,  0],
        [ 5, 51, 41,  ...,  0,  0,  0]])

In [52]:
test_segments = torch.stack([test_segment1,test_segment2]*16,0)

In [53]:
test_inputs = test_inputs.to(config.device)
test_segments = test_segments.to(config.device)

In [54]:
model.eval()
# model = model.to(config.device)
test_out = model(test_inputs,test_segments)

In [55]:
num = 1
test_sentence = '[CLS] what are your hobbies [SEP] i have been watching and playing [MASK] more than ten years , so that both are my hobbies'

print('test_sentence1')
print('MASK 예측 결과')
for j in [vocab.decode(i) for i in torch.topk(test_out[1][0][13],k=5)[1].cpu().detach().tolist()]:
    print(f"{num}. {j}")
    num += 1

test_sentence1
MASK 예측 결과
1. is
2. what
3. who
4. when
5. where


In [56]:
num = 1
print('test_sentence2')
print('MASK 예측 결과')
for j in [vocab.decode(i) for i in torch.topk(test_out[1][1][mask_idx2],k=5)[1].cpu().detach().tolist()]:
    print(f"{num}. {j}")
    num += 1

test_sentence2
MASK 예측 결과
1. is
2. what
3. who
4. when
5. where
