In [6]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm, trange

import csv
import random

from bert_util import *
%run ./BERT.ipynb

Hey False
Hidden state has shape of : (1, 128, 768)
intermediate output:  (1, 128, 3072)
Hidden state has shape of : (1, 128, 3072)


In [None]:
# Load every CSV row into memory. Judging by the row displayed below, each row is
# [sentiment_label, sentence_1, sentence_2, ...] for one review.
# The context manager guarantees the file handle is closed, and newline='' is the
# mode the csv module documents for files handed to csv.reader (it keeps newlines
# embedded in quoted fields intact).
with open('data/imdb_train.csv', newline='') as data_file:
    csv_reader = csv.reader(data_file)
    dataset = list(csv_reader)

In [None]:
# Inspect the first parsed row: the sentiment label followed by the review's sentences.
dataset[0]

['positive',
 "one of the other reviewers has mentioned that after watching just 1 oz episode you'll be hooked.",
 'they are right, as this is exactly what happened with me.',
 'the first thing that struck me about oz was its brutality and unflinching scenes of violence, which set in right from the word go.',
 'trust me, this is not a show for the faint hearted or timid.',
 'this show pulls no punches with regards to drugs, sex or violence.',
 'its is hardcore, in the classic use of the word.',
 'it is called oz as that is the nickname given to the oswald maximum security state penitentary.',
 'it focuses mainly on emerald city, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda.',
 'em city is home to many..aryans, muslims, gangstas, latinos, christians, italians, irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.',
 "i would say the main appeal of the show

In [None]:
# The sentiment label in the first column is not used for pretraining, so drop it.
# NOTE: this rebinds `dataset` in place, so re-running the cell removes one more
# leading column each time.
dataset = [row[1:] for row in dataset]

In [None]:
# Confirm the first row now holds only the review's sentences.
dataset[0]

["one of the other reviewers has mentioned that after watching just 1 oz episode you'll be hooked.",
 'they are right, as this is exactly what happened with me.',
 'the first thing that struck me about oz was its brutality and unflinching scenes of violence, which set in right from the word go.',
 'trust me, this is not a show for the faint hearted or timid.',
 'this show pulls no punches with regards to drugs, sex or violence.',
 'its is hardcore, in the classic use of the word.',
 'it is called oz as that is the nickname given to the oswald maximum security state penitentary.',
 'it focuses mainly on emerald city, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda.',
 'em city is home to many..aryans, muslims, gangstas, latinos, christians, italians, irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.',
 "i would say the main appeal of the show is due to th

In [None]:
# Reuse the vocabulary that ships with the pretrained model unchanged: one token
# per line, where the line index is the token id.
# The context manager closes the file handle, and the explicit encoding keeps the
# load independent of the platform's default locale.
# NOTE(review): assumes vocab.txt is UTF-8, as the published BERT vocabularies are — confirm.
with open('vocab.txt', encoding='utf-8') as vocab_file:
    vocab = [word.strip() for word in vocab_file]

In [None]:
# Reverse lookup for the vocabulary: token string -> its position in `vocab`.
word2id = {token: idx for idx, token in enumerate(vocab)}

In [None]:
# Word-piece tokenizer built over the pretrained vocabulary, with '[UNK]' passed as
# its second argument. WordpieceTokenizer is not defined in this notebook —
# presumably it comes from bert_util or BERT.ipynb; verify its signature there.
wp = WordpieceTokenizer(vocab, '[UNK]')

In [None]:
# Word-piece tokenize the first 1000 reviews, producing one token list per sentence
# and one list of sentences per review. tqdm draws the progress bar.
tokenized_dataset = []

for paragraph in tqdm(dataset[:1000]):
    tokenized_dataset.append([wp.tokenize(sentence) for sentence in paragraph])

100%|██████████| 1000/1000 [01:22<00:00, 12.07it/s]


In [None]:
# Word pieces of the first sentence of the first review.
tokenized_dataset[0][0]

['one',
 'of',
 'the',
 'other',
 'reviewers',
 'has',
 'mentioned',
 'that',
 'after',
 'watching',
 'just',
 '1',
 'oz',
 'episode',
 'you',
 "##'",
 '##ll',
 'be',
 'hooked',
 '##.']

In [None]:
# The same sentence as raw text, for comparison with its word pieces.
dataset[0][0]

"one of the other reviewers has mentioned that after watching just 1 oz episode you'll be hooked."

In [None]:
# Replace every word piece with its vocabulary id, keeping the
# paragraph -> sentence -> token nesting of tokenized_dataset.
indexed_dataset = [
    [[word2id[token] for token in sentence] for sentence in paragraph]
    for paragraph in tokenized_dataset
]

In [None]:
# Vocabulary ids of the first sentence of the first review.
indexed_dataset[0][0]

[2028,
 1997,
 1996,
 2060,
 15814,
 2038,
 3855,
 2008,
 2044,
 3666,
 2074,
 1015,
 11472,
 2792,
 2017,
 29618,
 3363,
 2022,
 13322,
 29625]

In [None]:
class PretrainDataset():
    """Endless sampler of Masked-Language-Modeling / Next-Sentence-Prediction examples.

    The dataset is a list of paragraphs; each paragraph is a list of sentences and
    each sentence is a list of vocabulary ids. Iterating over an instance yields
    fixed-length (``max_len``) pretraining examples forever.
    """

    def __init__(self, dataset, vocab):
        """Store the corpus and resolve the special-token ids.

        Args:
            dataset: List[List[List[int]]] -- paragraphs -> sentences -> token ids.
            vocab: List[str] -- vocabulary; a token's id is its position in this
                list. Must contain '[CLS]', '[SEP]' and '[MASK]'.

        Raises:
            KeyError: if one of the special tokens is missing from ``vocab``.
        """
        self.dataset = dataset
        self.vocab = vocab

        # Look the special tokens up in the vocabulary that was passed in rather
        # than in a notebook-level global, so the class works on its own.
        token_to_id = {token: idx for idx, token in enumerate(vocab)}
        self.CLS = token_to_id['[CLS]']
        self.SEP = token_to_id['[SEP]']
        self.MSK = token_to_id['[MASK]']

        self.PARA_NUM = len(self.dataset)
        self.par_len = [len(par) for par in dataset]
        self.max_len = 128

        # Paragraph indices usable by each sampler, computed once instead of per draw.
        # Negative sampling needs at least one sentence per paragraph; positive
        # sampling needs two consecutive sentences.
        self._nonempty_pars = [idx for idx, n in enumerate(self.par_len) if n > 0]
        self._multi_sentence_pars = [idx for idx, n in enumerate(self.par_len) if n > 1]

    @property
    def token_num(self):
        """Number of entries in the vocabulary."""
        return len(self.vocab)

    def _random_token(self, original):
        """Draw a vocabulary id that is neither a special token nor ``original``.

        Falls back to returning ``original`` when the vocabulary is too small to
        guarantee that such an id exists.
        """
        excluded = {self.CLS, self.SEP, self.MSK, original}
        vocab_size = self.token_num
        if vocab_size <= len(excluded):
            return original
        # More ids than exclusions, so at least one valid id exists and the
        # rejection loop terminates.
        while True:
            candidate = random.randrange(vocab_size)
            if candidate not in excluded:
                return candidate

    def masking(self, sen1, sen2):
        """Build the masked input and the mask flags for one sentence pair.

        15% of the tokens of ``sen1 + sen2`` (rounded down) are selected. Of the
        selected tokens roughly 80% are replaced by [MASK], 10% by a random
        vocabulary token and the remaining 10% are left unchanged.
        NOTE: pairs shorter than 7 tokens in total therefore get no selected
        position at all.

        Args:
            sen1: List[int] -- ids of the first sentence (not modified).
            sen2: List[int] -- ids of the second sentence (not modified).

        Returns:
            MLM_sentences: List[int] -- [CLS] sen1 [SEP] sen2 [SEP] with the
                replacements applied; length len(sen1) + len(sen2) + 3.
            MLM_mask: List[bool] -- same length; True at every selected position
                (including the ones left unchanged), False elsewhere and at the
                special tokens.
        """
        tokens = list(sen1) + list(sen2)
        boundary = len(sen1)

        # random.sample already returns the picks in random order, so the order can
        # be used directly to split them into the 80/10/10 groups.
        selected = random.sample(range(len(tokens)), int(len(tokens) * 0.15))
        selected_set = set(selected)

        mask = [pos in selected_set for pos in range(len(tokens))]
        MLM_mask = [False] + mask[:boundary] + [False] + mask[boundary:] + [False]

        n_selected = len(selected)
        for rank, pos in enumerate(selected):
            if rank < n_selected * 0.8:
                tokens[pos] = self.MSK
            elif rank < 0.9 * n_selected:
                tokens[pos] = self._random_token(tokens[pos])
            # else: keep the original token; it is still flagged in MLM_mask.

        MLM_sentences = [self.CLS] + tokens[:boundary] + [self.SEP] + tokens[boundary:] + [self.SEP]

        return MLM_sentences, MLM_mask

    def positive_sampling(self):
        """Return two consecutive sentences from one randomly chosen paragraph.

        Raises:
            IndexError: if no paragraph has at least two sentences.
        """
        paragraph = self.dataset[random.choice(self._multi_sentence_pars)]
        first = random.randrange(len(paragraph) - 1)

        return paragraph[first], paragraph[first + 1]

    def negative_sampling(self):
        """Return one random sentence from each of two distinct paragraphs.

        Empty paragraphs are skipped; picking one used to crash the sampler.

        Raises:
            ValueError: if fewer than two non-empty paragraphs exist.
        """
        par_id1, par_id2 = random.sample(self._nonempty_pars, 2)

        sen1 = random.choice(self.dataset[par_id1])
        sen2 = random.choice(self.dataset[par_id2])

        return sen1, sen2

    def __iter__(self):
        """Yield MLM & NSP pretraining examples forever.

        Each example is a positive pair (consecutive sentences) or a negative pair
        (sentences from different paragraphs) with probability 1/2 each. If the
        pair does not fit into ``max_len`` together with the three special tokens,
        both sentences are cut to ``(max_len - 3) // 2`` tokens. All sequences are
        zero-padded to ``max_len``.

        Yields:
            source_sentences: List[int] -- [CLS] sen1 [SEP] sen2 [SEP], unmasked
            MLM_sentences: List[int] -- the same sequence after masking
            attention_mask: List[int] -- 1 for real tokens, 0 for padding
            token_type_ids: List[int] -- 0 for [CLS] sen1 [SEP], 1 for sen2 [SEP],
                0 for padding
            MLM_mask: List[bool] -- True at the positions selected for MLM
            NSP_label: bool -- True if sen2 follows sen1 in the corpus

        Example (illustrative percentages: 25% selected; 50% <msk>, 25% random, 25% same):
        source_sentences = ['<cls>', 'He', 'bought', 'a', 'gallon', 'of', 'milk',
                            '<sep>', 'He', 'drank', 'it', 'all', 'on', 'the', 'spot', '<sep>']
        MLM_sentences = ['<cls>', 'He', '<msk>', 'a', 'gallon', 'of', 'milk',
                         '<sep>', 'He', 'drank', 'it', 'tree', 'on', '<msk>', 'spot', '<sep>']
        MLM_mask = [False, False, True, False, False, False, False,
                    False, True, False, False, True, False, True, False, False]
        NSP_label = True
        """
        while True:
            NSP_label = random.random() < 0.5

            if NSP_label:
                sen1, sen2 = self.positive_sampling()
            else:
                sen1, sen2 = self.negative_sampling()

            # Leave room for [CLS] and the two [SEP] tokens.
            budget = self.max_len - 3
            if len(sen1) + len(sen2) > budget:
                sen1 = sen1[:budget // 2]
                sen2 = sen2[:budget // 2]

            source_sentences = [self.CLS] + sen1 + [self.SEP] + sen2 + [self.SEP]
            MLM_sentences, MLM_mask = self.masking(sen1, sen2)

            attention_mask = [1] * len(source_sentences)
            token_type_ids = [0] * (len(sen1) + 2) + [1] * (len(sen2) + 1)

            # Zero padding up to max_len (a no-op when the sequence is already full).
            num_pad = self.max_len - len(source_sentences)
            if num_pad > 0:
                source_sentences = source_sentences + [0] * num_pad
                MLM_sentences = MLM_sentences + [0] * num_pad
                MLM_mask = MLM_mask + [False] * num_pad
                attention_mask = attention_mask + [0] * num_pad
                token_type_ids = token_type_ids + [0] * num_pad

            assert (len(source_sentences) == len(MLM_sentences) == len(MLM_mask)
                    == len(attention_mask) == len(token_type_ids))
            yield source_sentences, MLM_sentences, attention_mask, token_type_ids, MLM_mask, NSP_label



In [None]:
# Build the MLM/NSP example sampler over the id-encoded reviews.
trainset = PretrainDataset(indexed_dataset, vocab)

In [None]:
# Wrap the endless example generator in a tf.data pipeline and batch it.
# The element dtypes are listed in the order the generator yields them:
# source_sentences, MLM_sentences, attention_mask, token_type_ids, MLM_mask, NSP_label.
# NOTE: `dataset` is rebound here from the raw review list to the tf.data pipeline.
batch_size = 32
dataset = tf.data.Dataset.from_generator(
    trainset.__iter__,
    (tf.int32, tf.int32, tf.uint8, tf.int32, tf.uint8, tf.uint8),
).batch(batch_size)

In [None]:
# Pretraining loop: joint MLM + NSP objective optimised with Adam.
# BertConfig, TFBertMainLayer, compute_MLM_loss and compute_NSP_loss are not defined
# in this notebook — presumably they come from bert_util / BERT.ipynb; verify there.
config = BertConfig()
model = TFBertMainLayer(config)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# One [MLM_loss, NSP_loss] pair per step, stored exactly as the loss helpers return them.
losses = []
max_step = 100000
# The generator behind `dataset` never terminates, so the step cap is what ends training.
for step, (source_sentences, MLM_sentences, attention_mask, token_type_ids, MLM_mask, NSP_label) in enumerate(dataset):
    if step >= max_step:
        break
    #print(source_sentences, MLM_sentences, attention_mask, token_type_ids, MLM_mask, NSP_label)
    
    with tf.GradientTape() as tape:
        # The model is fed the masked sequence; the unmasked source_sentences are the targets.
        # NOTE(review): output[0] looks like per-token vocabulary scores and output[1]
        # like the NSP logits, judging by how they are used below — confirm against the model.
        output = model(MLM_sentences, attention_mask, token_type_ids, training=True)
        MLM_loss = compute_MLM_loss(source_sentences, MLM_mask, output[0])
        #print(MLM_loss.shape)
        #print(MLM_loss)
        NSP_loss = compute_NSP_loss(labels=NSP_label, logits=output[1])
        losses.append([MLM_loss, NSP_loss])
        # Both objectives are weighted equally.
        loss_value = MLM_loss + NSP_loss

    # Backpropagate the combined loss and apply one optimizer update.
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Every 100 steps, decode the first example of the batch for a qualitative check
    # and report the combined loss.
    if step % 100 == 0:
        print('original sentence', np.array(vocab)[source_sentences[0]])
        print('masked sentence', np.array(vocab)[MLM_sentences[0]])
        print('mask', MLM_mask[0])
        # argmax over the last axis picks the highest-scoring vocabulary id per position.
        print('predicted sentence', np.array(vocab)[tf.keras.backend.argmax(output[0], axis=-1)[0]])

        print(
            "Training loss (for one batch) at step %d: %.4f"
            % (step, float(loss_value))
        )
        print("Seen so far: %s samples" % ((step + 1) * batch_size))

tf.Tensor(
[[  101  1045  2245 ...     0     0     0]
 [  101  2012  1037 ...     0     0     0]
 [  101  2035 29624 ...     0     0     0]
 ...
 [  101  2138  1997 ...     0     0     0]
 [  101  2009  2987 ...     0     0     0]
 [  101  2026  3694 ...     0     0     0]], shape=(10, 128), dtype=int16) tf.Tensor(
[[  101  1045  2245 ...     0     0     0]
 [  101  2012  1037 ...     0     0     0]
 [  101  2035 29624 ...     0     0     0]
 ...
 [  101  2138  1997 ...     0     0     0]
 [  101  2009  2987 ...     0     0     0]
 [  101   103  3694 ...     0     0     0]], shape=(10, 128), dtype=int16) tf.Tensor(
[[ True  True  True ... False False False]
 [ True  True  True ... False False False]
 [ True  True  True ... False False False]
 ...
 [ True  True  True ... False False False]
 [ True  True  True ... False False False]
 [ True  True  True ... False False False]], shape=(10, 128), dtype=bool) tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ..

In [None]:
# Segment ids of the second example in the last batch bound by the training loop.
token_type_ids[1]

<tf.Tensor: id=51, shape=(128,), dtype=int16, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int16)>

In [None]:
# Unmasked token ids of the second example in the last batch bound by the training loop.
source_sentences[1]

<tf.Tensor: id=55, shape=(128,), dtype=int16, numpy=
array([  101,  2012,  1037,  2051,  2043,  2034, 29624, 27576, 28310,
        2066, 27785,  3523,  5196,  1998,  4895, 22852,  2977,  2024,
       18661,  2075,  2035,  1996,  3086,  1997,  3274, 27911,  2015,
       29623,  8425,  7357,  2024,  1037,  5996,  8843, 29625,   102,
        2007,  2307, 26136,  1998, 17211, 29623,  1996,  8364,  1997,
       10608,  2479,  2003,  1037,  2208,  2008,  2111,  1997,  2035,
        2287,  2967,  2052,  5959, 29625,   102,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,   

In [None]:
# Decode the second example of the batch back into word pieces, skipping the
# id-0 positions that were added as padding.
decoding_index = [vocab[token_id] for token_id in source_sentences[1] if token_id != 0]

In [None]:
# Show the decoded word pieces.
decoding_index

['[CLS]',
 'at',
 'a',
 'time',
 'when',
 'first',
 '##-',
 '##person',
 'shooters',
 'like',
 'quake',
 'iii',
 'arena',
 'and',
 'un',
 '##real',
 'tournament',
 'are',
 'garner',
 '##ing',
 'all',
 'the',
 'attention',
 'of',
 'computer',
 'gamer',
 '##s',
 '##,',
 'graphic',
 'adventures',
 'are',
 'a',
 'dying',
 'breed',
 '##.',
 '[SEP]',
 'with',
 'great',
 'pun',
 'and',
 'humour',
 '##,',
 'the',
 'curse',
 'of',
 'monkey',
 'island',
 'is',
 'a',
 'game',
 'that',
 'people',
 'of',
 'all',
 'age',
 'groups',
 'would',
 'enjoy',
 '##.',
 '[SEP]']