In [1]:
# !pip install transformers datasets tokenizers
# !wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
# !unzip -qq cornell_movie_dialogs_corpus.zip
# !rm cornell_movie_dialogs_corpus.zip
# !mkdir datasets
# !mv cornell\ movie-dialogs\ corpus/movie_conversations.txt ./datasets
# !mv cornell\ movie-dialogs\ corpus/movie_lines.txt ./datasets

### Trained on two objectives
- Next Sentence Prediction
- Masked Language Modeling i.e. masked word prediction

In [1]:
import os
from pathlib import Path
import torch
import re
import random
import transformers, datasets
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer
import tqdm
from torch.utils.data import Dataset, DataLoader
import itertools
import math
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam

# 1 ) Tokenization (Word Piece Tokenizer)

[Huggingface WordPieceTokenizer](https://huggingface.co/learn/nlp-course/chapter6/6?fw=pt)

The tokenizer's primary job is to split the input text into smaller tokens. These tokens are usually words, subwords (WordPiece tokens), or characters, depending on the specific tokenizer and its configuration.

Subword Tokenization (WordPiece): BERT often uses subword tokenization, where words are further divided into smaller units called subword tokens. For instance, "unhappiness" might be broken down into ["un", "##hap", "##piness"]


By dividing the frequency of the pair by the product of the frequencies of each of its parts, the algorithm prioritizes the merging of pairs where the individual parts are less frequent in the vocabulary.

**score=(freq_of_pair)/(freq_of_first_element×freq_of_second_element)**

## 1.1 Tokenizer from Scratch

In [35]:
from collections import defaultdict
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [3]:
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

### get the frequency of each word ###
word_freqs = defaultdict(int)
for text in corpus:
    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    new_words = [word for word, offset in words_with_offsets]
    print(new_words)
    for word in new_words:
        word_freqs[word] += 1

print(f"\nFinal Word Frequency: {word_freqs}")

['This', 'is', 'the', 'Hugging', 'Face', 'Course', '.']
['This', 'chapter', 'is', 'about', 'tokenization', '.']
['This', 'section', 'shows', 'several', 'tokenizer', 'algorithms', '.']
['Hopefully', ',', 'you', 'will', 'be', 'able', 'to', 'understand', 'how', 'they', 'are', 'trained', 'and', 'generate', 'tokens', '.']

Final Word Frequency: defaultdict(<class 'int'>, {'This': 3, 'is': 2, 'the': 1, 'Hugging': 1, 'Face': 1, 'Course': 1, '.': 4, 'chapter': 1, 'about': 1, 'tokenization': 1, 'section': 1, 'shows': 1, 'several': 1, 'tokenizer': 1, 'algorithms': 1, 'Hopefully': 1, ',': 1, 'you': 1, 'will': 1, 'be': 1, 'able': 1, 'to': 1, 'understand': 1, 'how': 1, 'they': 1, 'are': 1, 'trained': 1, 'and': 1, 'generate': 1, 'tokens': 1})


In [4]:
### split all word into alphabet ###
alphabet = []
for word in word_freqs.keys():
    if word[0] not in alphabet:
        alphabet.append(word[0])
    for letter in word[1:]:
        if f"##{letter}" not in alphabet:
            alphabet.append(f"##{letter}")

alphabet.sort()
print(f'All alphabets: {alphabet}')

All alphabets: ['##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y']


In [5]:
### insert special token and subword ###
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + alphabet.copy()
splits = {word: [c if i == 0 else f"##{c}" for i, c in enumerate(word)] for word in word_freqs.keys()}
print(f'\nSplitted Words: {splits}')


Splitted Words: {'This': ['T', '##h', '##i', '##s'], 'is': ['i', '##s'], 'the': ['t', '##h', '##e'], 'Hugging': ['H', '##u', '##g', '##g', '##i', '##n', '##g'], 'Face': ['F', '##a', '##c', '##e'], 'Course': ['C', '##o', '##u', '##r', '##s', '##e'], '.': ['.'], 'chapter': ['c', '##h', '##a', '##p', '##t', '##e', '##r'], 'about': ['a', '##b', '##o', '##u', '##t'], 'tokenization': ['t', '##o', '##k', '##e', '##n', '##i', '##z', '##a', '##t', '##i', '##o', '##n'], 'section': ['s', '##e', '##c', '##t', '##i', '##o', '##n'], 'shows': ['s', '##h', '##o', '##w', '##s'], 'several': ['s', '##e', '##v', '##e', '##r', '##a', '##l'], 'tokenizer': ['t', '##o', '##k', '##e', '##n', '##i', '##z', '##e', '##r'], 'algorithms': ['a', '##l', '##g', '##o', '##r', '##i', '##t', '##h', '##m', '##s'], 'Hopefully': ['H', '##o', '##p', '##e', '##f', '##u', '##l', '##l', '##y'], ',': [','], 'you': ['y', '##o', '##u'], 'will': ['w', '##i', '##l', '##l'], 'be': ['b', '##e'], 'able': ['a', '##b', '##l', '##e'], 't

In [6]:
 ### compute score for merging ###
def compute_pair_scores(splits):
    letter_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)

    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            letter_freqs[split[0]] += freq
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            letter_freqs[split[i]] += freq
            pair_freqs[pair] += freq
        letter_freqs[split[-1]] += freq

    scores = {
        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }
    return scores

pair_scores = compute_pair_scores(splits)
print(f'Scores for each Pair: {pair_scores}')

Scores for each Pair: {('T', '##h'): 0.125, ('##h', '##i'): 0.03409090909090909, ('##i', '##s'): 0.02727272727272727, ('i', '##s'): 0.1, ('t', '##h'): 0.03571428571428571, ('##h', '##e'): 0.011904761904761904, ('H', '##u'): 0.1, ('##u', '##g'): 0.05, ('##g', '##g'): 0.0625, ('##g', '##i'): 0.022727272727272728, ('##i', '##n'): 0.01652892561983471, ('##n', '##g'): 0.022727272727272728, ('F', '##a'): 0.14285714285714285, ('##a', '##c'): 0.07142857142857142, ('##c', '##e'): 0.023809523809523808, ('C', '##o'): 0.07692307692307693, ('##o', '##u'): 0.046153846153846156, ('##u', '##r'): 0.022222222222222223, ('##r', '##s'): 0.022222222222222223, ('##s', '##e'): 0.004761904761904762, ('c', '##h'): 0.125, ('##h', '##a'): 0.017857142857142856, ('##a', '##p'): 0.07142857142857142, ('##p', '##t'): 0.07142857142857142, ('##t', '##e'): 0.013605442176870748, ('##e', '##r'): 0.026455026455026454, ('a', '##b'): 0.2, ('##b', '##o'): 0.038461538461538464, ('##u', '##t'): 0.02857142857142857, ('t', '##o')

In [7]:
### finding pair with best score ###
best_pair = ""
max_score = None
for pair, score in pair_scores.items():
    if max_score is None or max_score < score:
        best_pair = pair
        max_score = score

print(best_pair, max_score)
vocab.append("ab")

### merge pair ###
def merge_pair(a, b, splits):
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                merge = a + b[2:] if b.startswith("##") else a + b
                split = split[:i] + [merge] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

splits = merge_pair("a", "##b", splits)
print(splits["about"])

('a', '##b') 0.2
['ab', '##o', '##u', '##t']


In [8]:
### keep looping to merge more pair
vocab_size = 70
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits)
    best_pair, max_score = "", None
    for pair, score in scores.items():
        if max_score is None or max_score < score:
            best_pair = pair
            max_score = score
    splits = merge_pair(*best_pair, splits)
    new_token = (
        best_pair[0] + best_pair[1][2:]
        if best_pair[1].startswith("##")
        else best_pair[0] + best_pair[1]
    )
    vocab.append(new_token)

print(f'Final Vocab: {vocab}')

Final Vocab: ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', 'ab', '##fu', 'Fa', 'Fac', '##ct', '##ful', '##full', '##fully', 'Th', 'ch', '##hm', 'cha', 'chap', 'chapt', '##thm', 'Hu', 'Hug', 'Hugg', 'sh', 'th', 'is', '##thms', '##za', '##zat', '##ut']


In [9]:
### ro encode a word ###
def encode_word(word):
    tokens = []
    while len(word) > 0:
        i = len(word)
        while i > 0 and word[:i] not in vocab:
            i -= 1
        if i == 0:
            return ["[UNK]"]
        tokens.append(word[:i])
        word = word[i:]
        if len(word) > 0:
            word = f"##{word}"
    return tokens

print(encode_word("Hugging"))
print(encode_word("HOgging"))

['Hugg', '##i', '##n', '##g']
['[UNK]']


## 1.2 Tokenizer Training

In [10]:
### data processing
MAX_LEN = 64

### loading all data into memory
corpus_movie_conv = '../../data/movie_conversations.txt'
corpus_movie_lines = '../../data/movie_lines.txt'
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as c:
    conv = c.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as l:
    lines = l.readlines()

In [11]:
### splitting text using special lines
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]

In [12]:
lines_dic

{'L1045': 'They do not!\n',
 'L1044': 'They do to!\n',
 'L985': 'I hope so.\n',
 'L984': 'She okay?\n',
 'L925': "Let's go.\n",
 'L924': 'Wow\n',
 'L872': "Okay -- you're gonna need to learn how to lie.\n",
 'L871': 'No\n',
 'L870': 'I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n',
 'L869': 'Like my fear of wearing pastels?\n',
 'L868': 'The "real you".\n',
 'L867': 'What good stuff?\n',
 'L866': "I figured you'd get to the good stuff eventually.\n",
 'L865': 'Thank God!  If I had to hear one more story about your coiffure...\n',
 'L864': "Me.  This endless ...blonde babble. I'm like, boring myself.\n",
 'L863': 'What crap?\n',
 'L862': 'do you listen to this crap?\n',
 'L861': 'No...\n',
 'L860': 'Then Guillermo says, "If you go any lighter, you\'re gonna look like an extra on 90210."\n',
 'L699': 'You always been this selfish?\n',
 'L698': 'But\n',
 'L697': "Then that's all you had to say.\n",
 'L696': 'Well, no...\n',
 'L695

In [13]:
### generate question answer pairs
pairs = []
for con in conv:
    ids = eval(con.split(" +++$+++ ")[-1])
    for i in range(len(ids)):
        qa_pairs = []

        if i == len(ids) - 1:
            break

        first = lines_dic[ids[i]].strip()
        second = lines_dic[ids[i+1]].strip()

        qa_pairs.append(' '.join(first.split()[:MAX_LEN]))
        qa_pairs.append(' '.join(second.split()[:MAX_LEN]))
        pairs.append(qa_pairs)

In [14]:
# sample
print(pairs[20])

["I really, really, really wanna go, but I can't. Not unless my sister goes.", "I'm workin' on it. But she doesn't seem to be goin' for him."]


In [15]:
# WordPiece tokenizer

### save data as txt file
# os.mkdir('./data')
text_data = []
file_count = 0

for sample in tqdm.tqdm([x[0] for x in pairs]):
    text_data.append(sample)

    # once we hit the 10K mark, save to file
    if len(text_data) == 10000:
        with open(f'../../data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(text_data))
        text_data = []
        file_count += 1

paths = [str(x) for x in Path('../../data').glob('**/text_*.txt')]
print(len(paths))

100%|██████████████████████████████| 221616/221616 [00:00<00:00, 1934156.93it/s]

22





In [16]:
### training own tokenizer
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=True
)

tokenizer.train(
    files=paths,
    vocab_size=30_000,
    min_frequency=5,
    limit_alphabet=1000,
    wordpieces_prefix='##',
    special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']
    )






In [17]:
# os.mkdir('../../data/bert-it-1')
# tokenizer.save_model('../../data/bert-it-1', 'bert-it')
tokenizer = BertTokenizer.from_pretrained('../../data/bert-it-1/bert-it-vocab.txt', local_files_only=True)
token_ids = tokenizer('I like surfboarding!')['input_ids']
print(token_ids)
print(tokenizer.convert_ids_to_tokens(token_ids))

[1, 48, 250, 4038, 3625, 154, 5, 2]
['[CLS]', 'i', 'like', 'surf', '##board', '##ing', '!', '[SEP]']




In [18]:
tokenizer.vocab["[MASK]"]

3

# 2) Pre-processing

In [19]:
class BERTDataset(Dataset):
    def __init__(self, data_pair, tokenizer, seq_len=64):

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair

    def __len__(self):
        return self.corpus_lines

    def __getitem__(self, item):

        # Step 1: get random sentence pair, either negative or positive (saved as is_next_label)
        t1, t2, is_next_label = self.get_sent(item)

        # Step 2: replace random words in sentence with mask / random words
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Step 3: Adding CLS and SEP tokens to the start and end of sentences
        # Adding PAD token for labels
        t1 = [self.tokenizer.vocab['[CLS]']] + t1_random + [self.tokenizer.vocab['[SEP]']]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        t1_label = [self.tokenizer.vocab['[PAD]']] + t1_label + [self.tokenizer.vocab['[PAD]']]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]

        # Step 4: combine sentence 1 and 2 as one input
        # adding PAD tokens to make the sentence same length as seq_len
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]
        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value).to("cuda") for key, value in output.items()}

    def random_word(self, sentence):
        tokens = sentence.split()
        output_label = []
        output = []

        # 15% of the tokens would be replaced
        for i, token in enumerate(tokens):
            prob = random.random()

            # remove cls and sep token
            token_id = self.tokenizer(token)['input_ids'][1:-1]
            #print(token, "--->", token_id)
            # 15% chance of altering token
            if prob < 0.15:
                prob /= 0.15

                # 80% chance change token to mask token
                if prob < 0.8:
                    for i in range(len(token_id)):
                        if i/len(token_id) < prob:
                            output.append(self.tokenizer.vocab['[MASK]'])
                        else:
                            output.append(token_id[i])

                # 10% chance change token to random token
                elif prob < 0.9:
                    for i in range(len(token_id)):
                        if i/len(token_id) < prob:
                            output.append(random.randrange(len(self.tokenizer.vocab)))
                        else:
                            output.append(token_id[i])

                # 10% chance change token to current token
                else:
                    output.append(token_id)

                output_label.append(token_id)

            else:
                output.append(token_id)
                for i in range(len(token_id)):
                    output_label.append(0)

        # flattening
        output = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output]))
        output_label = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output_label]))
        assert len(output) == len(output_label)
        return output, output_label

    def get_sent(self, index):
        '''return random sentence pair'''
        t1, t2 = self.get_corpus_line(index)

        # negative or positive pair, for next sentence prediction
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        '''return sentence pair'''
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        '''return random single sentence'''
        return self.lines[random.randrange(len(self.lines))][1]

In [20]:
# test
print("\n")
train_data = BERTDataset(pairs, seq_len=MAX_LEN, tokenizer=tokenizer)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, pin_memory=False)





In [21]:
sample_data = next(iter(train_loader))
print('Batch Size', sample_data['bert_input'].size())

Batch Size torch.Size([32, 64])


In [22]:
sample_data

{'bert_input': tensor([[    1, 11709,    11,  ...,     0,     0,     0],
         [    1,    48,    11,  ...,     0,     0,     0],
         [    1,   247,   146,  ...,     0,     0,     0],
         ...,
         [    1,   178,     3,  ...,     0,     0,     0],
         [    1,   185,    34,  ...,     0,     0,     0],
         [    1, 10075,   460,  ...,     0,     0,     0]], device='cuda:0'),
 'bert_label': tensor([[  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         ...,
         [  0,   0, 182,  ...,   0,   0,   0],
         [  0, 185,  34,  ...,   0,   0,   0],
         [  0, 146,   0,  ...,   0,   0,   0]], device='cuda:0'),
 'segment_label': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
 'is_next'

In [23]:
tokenizer.convert_ids_to_tokens(sample_data['bert_input'][9])

['[CLS]',
 'hot',
 'dogs',
 '?',
 '[SEP]',
 '[MASK]',
 'don',
 "'",
 't',
 '[MASK]',
 'so',
 '.',
 '[SEP]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]']

In [24]:
# 3 is MASK
result = train_data[random.randrange(len(train_data))]
result

{'bert_input': tensor([    1,   443,     3,     3,    48,   514,   146,   358,   700,    17,
            48,   301,   146, 12389,   393,   211,  4885,  3436,  9685,    17,
             2,     3,    15,   237,    48,  1850,   274,    17,   146,   231,
           410,   162,    17,     3,    17,     2,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0], device='cuda:0'),
 'bert_label': tensor([    0,     0,  1060,    15,     0,     0,     0,     0,     0,     0,
             0,     0,     0, 18633,     0,     0,     0,     0,     0,     0,
             0,   368,    15,     0,     0,     0,     0,     0,     0,     0,
           410,   162,    17,   430,    17,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,    

In [25]:
train_data_ = BERTDataset(pairs[11:12], seq_len=MAX_LEN, tokenizer=tokenizer)

In [26]:
pairs[11]

["Right. See? You're ready for the quiz.",
 "I don't want to know how to say that though. I want to know useful things. Like where the good stores are. How much does champagne cost? Stuff like Chat. I have never in my life had to point out my head to someone."]

In [27]:
train_data_[0]

{'bert_input': tensor([    1,   308,    17,   301,    34,   146,    11,   181,  1144,   202,
           150, 19713,    17,     2,     3,   204,    11,    59,   258,   153,
           210,   268,   153,   311,   173,  1342,    17,    48,   258,   153,
           210,  5690,   583,    17,   250,   333,   150,   334,  9108,   235,
            17,   268,     3,   420,  4501,  1740,    34,     3,   250,  4151,
            17,    48,   217,   408,   171,   218,   552,   387,   153,   972,
           254,   218,   731,     3], device='cuda:0'),
 'bert_label': tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          48,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0, 153,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         493,   0,   0,   0,   0, 919,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0, 153], device='cuda:0'),
 'segment_label': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

# 3) Modeling

In [28]:
### embedding
class PositionalEmbedding_(torch.nn.Module):

    def __init__(self, d_model, max_len=128):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        for pos in range(max_len):
            # for each dimension of the each position
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))

        # include the batch size
        self.pe = pe.unsqueeze(0)
        # self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe
    

class PositionalEmbedding(torch.nn.Module):
    def __init__(self, d_model, max_len=128):
        super().__init__()
        pe = torch.zeros(max_len, d_model).float()
        for pos in range(max_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
        self.register_buffer('pe', pe.unsqueeze(0))  # Ensures correct device handling

    def forward(self, x):
        return self.pe[:, :x.size(1), :].to(x.device) 
    

class BERTEmbedding(torch.nn.Module):
    """
    BERT Embedding which is consisted with under features
        1. TokenEmbedding : normal embedding matrix
        2. PositionalEmbedding : adding positional information using sin, cos
        2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2)
        sum of all these features are output of BERTEmbedding
    """

    def __init__(self, vocab_size, embed_size, seq_len=128, dropout=0.1):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param dropout: dropout rate
        """

        super().__init__()
        self.embed_size = embed_size
        # (m, seq_len) --> (m, seq_len, embed_size)
        # padding_idx is not updated during training, remains as fixed pad (0)
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0)
        self.position = PositionalEmbedding(d_model=embed_size, max_len=seq_len)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)

### testing
embed_layer = BERTEmbedding(vocab_size=len(tokenizer.vocab), embed_size=768, seq_len=MAX_LEN).to("cuda")
embed_result = embed_layer(sample_data['bert_input'], sample_data['segment_label'])
print(embed_result.size())

torch.Size([32, 64, 768])


In [29]:
### attention layers
class MultiHeadedAttention(torch.nn.Module):

    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()

        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.output_linear = torch.nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # fill 0 mask with super small number so it wont affect the softmax weight
        # (batch_size, h, max_len, max_len)
        scores = scores.masked_fill(mask == 0, -1e9)

        # (batch_size, h, max_len, max_len)
        # softmax to put attention weight for all non-pad tokens
        # max_len X max_len matrix of attention
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)

class FeedForward(torch.nn.Module):
    "Implements FFN equation"

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        d_model=768,
        heads=12,
        feed_forward_hidden=768 * 4,
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

### testing
mask = (sample_data['bert_input'] > 0).unsqueeze(1).repeat(1, sample_data['bert_input'].size(1), 1).unsqueeze(1)
transformer_block = EncoderLayer().to("cuda")
transformer_result = transformer_block(embed_result, mask)
transformer_result.size()

torch.Size([32, 64, 768])

In [30]:
class BERT(torch.nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=256, n_layers=8, heads=4, max_len=128, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # paper noted they used 4*hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model, seq_len=max_len).to("cuda")

        # multi-layers transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout).to("cuda") for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded token
        # (batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)
        
        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)
        return x

class NextSentencePrediction(torch.nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 2)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token which is the [CLS]
        return self.softmax(self.linear(x[:, 0]))

class MaskedLanguageModel(torch.nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

class BERTLM(torch.nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.d_model)
        self.mask_lm = MaskedLanguageModel(self.bert.d_model, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)

### test
bert_model = BERT(len(tokenizer.vocab)).to("cuda")
bert_result = bert_model(sample_data['bert_input'], sample_data['segment_label'])
print(bert_result.size())

bert_lm = BERTLM(bert_model, len(tokenizer.vocab)).to("cuda")
final_result = bert_lm(sample_data['bert_input'], sample_data['segment_label'])
print(final_result[0].size(), final_result[1].size())

torch.Size([32, 64, 256])
torch.Size([32, 2]) torch.Size([32, 64, 21304])


# 4) Training

In [31]:
### optimizer
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

In [34]:
### trainer
class BERTTrainer:
    def __init__(
        self,
        model,
        train_dataloader,
        test_dataloader=None,
        lr= 1e-5,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10,
        device='cuda'
        ):

        self.device = device
        self.model = model
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
            )

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = torch.nn.functional.cross_entropy #torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, data in data_iter:

            # 0. batch_data will be sent into the device(GPU or cpu)
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])

            # 2-1. NLL(negative log likelihood) loss of is_next classification result
            next_loss = self.criterion(next_sent_output, data["is_next"])

            # 2-2. NLLLoss of predicting masked token word
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            # 3. backward and optimization only in train
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        print(
            f"EP{epoch}, {mode}: \
            avg_loss={avg_loss / len(data_iter)}, \
            total_acc={total_correct * 100.0 / total_element}"
        )

### test
train_data = BERTDataset(pairs, seq_len=MAX_LEN, tokenizer=tokenizer)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, pin_memory=False)
bert_model = BERT(len(tokenizer.vocab)).to("cuda")
bert_lm = BERTLM(bert_model, len(tokenizer.vocab)).to("cuda")
bert_trainer = BERTTrainer(bert_lm, train_loader, device='cuda')
epochs = 2

for epoch in range(epochs):
    bert_trainer.train(epoch)

Total Parameters: 17244218


EP_train:0:   0%|| 1/3463 [00:00<17:41,  3.26it/s]

{'epoch': 0, 'iter': 0, 'avg_loss': 11.064160346984863, 'avg_acc': 48.4375, 'loss': 11.064160346984863}


EP_train:0:   0%|| 11/3463 [00:02<15:04,  3.82it/s]

{'epoch': 0, 'iter': 10, 'avg_loss': 11.042526591907848, 'avg_acc': 49.715909090909086, 'loss': 10.989319801330566}


EP_train:0:   1%|| 21/3463 [00:05<14:40,  3.91it/s]

{'epoch': 0, 'iter': 20, 'avg_loss': 10.952872821262904, 'avg_acc': 49.107142857142854, 'loss': 10.732012748718262}


EP_train:0:   1%|| 31/3463 [00:08<14:37,  3.91it/s]

{'epoch': 0, 'iter': 30, 'avg_loss': 10.81309066280242, 'avg_acc': 49.596774193548384, 'loss': 10.323813438415527}


EP_train:0:   1%|| 41/3463 [00:10<15:02,  3.79it/s]

{'epoch': 0, 'iter': 40, 'avg_loss': 10.620949303231589, 'avg_acc': 49.885670731707314, 'loss': 9.710350036621094}


EP_train:0:   1%|| 51/3463 [00:13<15:32,  3.66it/s]

{'epoch': 0, 'iter': 50, 'avg_loss': 10.377413001714968, 'avg_acc': 50.18382352941176, 'loss': 9.083040237426758}


EP_train:0:   2%|| 61/3463 [00:16<15:44,  3.60it/s]

{'epoch': 0, 'iter': 60, 'avg_loss': 10.094384881316639, 'avg_acc': 50.435450819672134, 'loss': 8.253893852233887}


EP_train:0:   2%|| 71/3463 [00:19<15:28,  3.65it/s]

{'epoch': 0, 'iter': 70, 'avg_loss': 9.764720191418284, 'avg_acc': 50.63820422535211, 'loss': 7.337224960327148}


EP_train:0:   2%|| 81/3463 [00:21<14:49,  3.80it/s]

{'epoch': 0, 'iter': 80, 'avg_loss': 9.403566766668249, 'avg_acc': 50.88734567901234, 'loss': 6.45443058013916}


EP_train:0:   3%|| 91/3463 [00:24<15:38,  3.59it/s]

{'epoch': 0, 'iter': 90, 'avg_loss': 9.031156277918553, 'avg_acc': 50.94436813186813, 'loss': 5.697570323944092}


EP_train:0:   3%|| 101/3463 [00:27<15:30,  3.61it/s]

{'epoch': 0, 'iter': 100, 'avg_loss': 8.651605417232702, 'avg_acc': 50.58787128712871, 'loss': 4.902891159057617}


EP_train:0:   3%|| 111/3463 [00:29<14:57,  3.73it/s]

{'epoch': 0, 'iter': 110, 'avg_loss': 8.274990867924046, 'avg_acc': 50.54898648648649, 'loss': 4.161083698272705}


EP_train:0:   3%|| 121/3463 [00:32<15:14,  3.65it/s]

{'epoch': 0, 'iter': 120, 'avg_loss': 7.912814902865197, 'avg_acc': 50.503615702479344, 'loss': 3.644195795059204}


EP_train:0:   4%|| 131/3463 [00:35<15:35,  3.56it/s]

{'epoch': 0, 'iter': 130, 'avg_loss': 7.575759179719531, 'avg_acc': 50.50095419847328, 'loss': 3.22108793258667}


EP_train:0:   4%|| 141/3463 [00:38<15:19,  3.61it/s]

{'epoch': 0, 'iter': 140, 'avg_loss': 7.265835645351004, 'avg_acc': 50.47650709219859, 'loss': 3.1240081787109375}


EP_train:0:   4%|| 151/3463 [00:41<15:24,  3.58it/s]

{'epoch': 0, 'iter': 150, 'avg_loss': 6.981944695213772, 'avg_acc': 50.403559602649004, 'loss': 2.8914730548858643}


EP_train:0:   5%|| 161/3463 [00:43<15:20,  3.59it/s]

{'epoch': 0, 'iter': 160, 'avg_loss': 6.71947419421273, 'avg_acc': 50.38819875776398, 'loss': 2.731844902038574}


EP_train:0:   5%|| 171/3463 [00:46<15:14,  3.60it/s]

{'epoch': 0, 'iter': 170, 'avg_loss': 6.48205214076572, 'avg_acc': 50.447733918128655, 'loss': 2.5523569583892822}


EP_train:0:   5%|| 181/3463 [00:49<15:59,  3.42it/s]

{'epoch': 0, 'iter': 180, 'avg_loss': 6.265881206449224, 'avg_acc': 50.58701657458563, 'loss': 2.5329079627990723}


EP_train:0:   6%|| 191/3463 [00:52<15:52,  3.44it/s]

{'epoch': 0, 'iter': 190, 'avg_loss': 6.06890596769243, 'avg_acc': 50.646269633507856, 'loss': 2.5480520725250244}


EP_train:0:   6%|| 201/3463 [00:55<14:55,  3.64it/s]

{'epoch': 0, 'iter': 200, 'avg_loss': 5.887733497429843, 'avg_acc': 50.60634328358209, 'loss': 2.4269275665283203}


EP_train:0:   6%|| 211/3463 [00:58<15:47,  3.43it/s]

{'epoch': 0, 'iter': 210, 'avg_loss': 5.721935821370491, 'avg_acc': 50.770142180094794, 'loss': 2.4148659706115723}


EP_train:0:   6%|| 221/3463 [01:00<14:43,  3.67it/s]

{'epoch': 0, 'iter': 220, 'avg_loss': 5.569748076917898, 'avg_acc': 50.664592760180994, 'loss': 2.3990750312805176}


EP_train:0:   7%|| 231/3463 [01:03<14:54,  3.61it/s]

{'epoch': 0, 'iter': 230, 'avg_loss': 5.427403742100769, 'avg_acc': 50.62229437229438, 'loss': 2.2664451599121094}


EP_train:0:   7%|| 241/3463 [01:06<15:02,  3.57it/s]

{'epoch': 0, 'iter': 240, 'avg_loss': 5.295655626479026, 'avg_acc': 50.615923236514526, 'loss': 2.1995253562927246}


EP_train:0:   7%|| 251/3463 [01:09<14:57,  3.58it/s]

{'epoch': 0, 'iter': 250, 'avg_loss': 5.173335044032549, 'avg_acc': 50.535358565737056, 'loss': 2.256549119949341}


EP_train:0:   8%|| 261/3463 [01:12<14:36,  3.65it/s]

{'epoch': 0, 'iter': 260, 'avg_loss': 5.057576077194506, 'avg_acc': 50.71240421455939, 'loss': 2.0968105792999268}


EP_train:0:   8%|| 271/3463 [01:14<14:41,  3.62it/s]

{'epoch': 0, 'iter': 270, 'avg_loss': 4.948595905655864, 'avg_acc': 50.743773062730625, 'loss': 2.0928256511688232}


EP_train:0:   8%|| 281/3463 [01:17<14:16,  3.71it/s]

{'epoch': 0, 'iter': 280, 'avg_loss': 4.844958416931994, 'avg_acc': 50.70062277580071, 'loss': 2.063168525695801}


EP_train:0:   8%|| 291/3463 [01:20<15:03,  3.51it/s]

{'epoch': 0, 'iter': 290, 'avg_loss': 4.748857915606286, 'avg_acc': 50.69802405498282, 'loss': 2.2306814193725586}


EP_train:0:   9%|| 301/3463 [01:23<15:24,  3.42it/s]

{'epoch': 0, 'iter': 300, 'avg_loss': 4.658086286826784, 'avg_acc': 50.814991694352166, 'loss': 2.0115902423858643}


EP_train:0:   9%|| 311/3463 [01:25<14:10,  3.71it/s]

{'epoch': 0, 'iter': 310, 'avg_loss': 4.571768634572289, 'avg_acc': 50.854099678456585, 'loss': 1.8984904289245605}


EP_train:0:   9%|| 321/3463 [01:28<14:49,  3.53it/s]

{'epoch': 0, 'iter': 320, 'avg_loss': 4.490426055738859, 'avg_acc': 50.803154205607484, 'loss': 1.9634464979171753}


EP_train:0:  10%|| 331/3463 [01:31<14:45,  3.54it/s]

{'epoch': 0, 'iter': 330, 'avg_loss': 4.412451852844561, 'avg_acc': 50.807212990936556, 'loss': 1.8468270301818848}


EP_train:0:  10%|| 341/3463 [01:34<14:30,  3.59it/s]

{'epoch': 0, 'iter': 340, 'avg_loss': 4.337791433082647, 'avg_acc': 50.81103372434017, 'loss': 1.8333005905151367}


EP_train:0:  10%|| 351/3463 [01:37<14:13,  3.65it/s]

{'epoch': 0, 'iter': 350, 'avg_loss': 4.266229127207373, 'avg_acc': 50.81463675213676, 'loss': 1.8416192531585693}


EP_train:0:  10%|| 361/3463 [01:39<14:42,  3.51it/s]

{'epoch': 0, 'iter': 360, 'avg_loss': 4.199109389841391, 'avg_acc': 50.76610110803325, 'loss': 1.809220314025879}


EP_train:0:  11%|| 371/3463 [01:42<14:21,  3.59it/s]

{'epoch': 0, 'iter': 370, 'avg_loss': 4.1335630741402145, 'avg_acc': 50.787567385444746, 'loss': 1.7230408191680908}


EP_train:0:  11%|| 381/3463 [01:45<14:59,  3.43it/s]

{'epoch': 0, 'iter': 380, 'avg_loss': 4.071981276114156, 'avg_acc': 50.799704724409445, 'loss': 1.7700457572937012}


EP_train:0:  11%|| 391/3463 [01:48<14:20,  3.57it/s]

{'epoch': 0, 'iter': 390, 'avg_loss': 4.012487551745246, 'avg_acc': 50.80322890025576, 'loss': 1.7577838897705078}


EP_train:0:  12%|| 401/3463 [01:51<14:18,  3.57it/s]

{'epoch': 0, 'iter': 400, 'avg_loss': 3.955343697433757, 'avg_acc': 50.76371571072319, 'loss': 1.7360895872116089}


EP_train:0:  12%|| 411/3463 [01:54<16:04,  3.17it/s]

{'epoch': 0, 'iter': 410, 'avg_loss': 3.900573527435897, 'avg_acc': 50.669099756691, 'loss': 1.7131435871124268}


EP_train:0:  12%|| 421/3463 [01:56<14:12,  3.57it/s]

{'epoch': 0, 'iter': 420, 'avg_loss': 3.8484418547918, 'avg_acc': 50.638361045130644, 'loss': 1.6981916427612305}


EP_train:0:  12%|| 431/3463 [01:59<14:20,  3.52it/s]

{'epoch': 0, 'iter': 430, 'avg_loss': 3.7980961368144803, 'avg_acc': 50.605423433874705, 'loss': 1.6600944995880127}


EP_train:0:  13%|| 441/3463 [02:02<14:02,  3.59it/s]

{'epoch': 0, 'iter': 440, 'avg_loss': 3.749583765763004, 'avg_acc': 50.577522675736965, 'loss': 1.5634042024612427}


EP_train:0:  13%|| 451/3463 [02:05<13:51,  3.62it/s]

{'epoch': 0, 'iter': 450, 'avg_loss': 3.703138023416642, 'avg_acc': 50.50235587583148, 'loss': 1.7224621772766113}


EP_train:0:  13%|| 461/3463 [02:08<13:46,  3.63it/s]

{'epoch': 0, 'iter': 460, 'avg_loss': 3.6580476104049517, 'avg_acc': 50.508405639913235, 'loss': 1.60787034034729}


EP_train:0:  14%|| 471/3463 [02:11<14:00,  3.56it/s]

{'epoch': 0, 'iter': 470, 'avg_loss': 3.6136383623074573, 'avg_acc': 50.43789808917197, 'loss': 1.6017255783081055}


EP_train:0:  14%|| 481/3463 [02:13<13:56,  3.56it/s]

{'epoch': 0, 'iter': 480, 'avg_loss': 3.570725475923931, 'avg_acc': 50.432042619542614, 'loss': 1.6276061534881592}


EP_train:0:  14%|| 491/3463 [02:16<13:54,  3.56it/s]

{'epoch': 0, 'iter': 490, 'avg_loss': 3.5299169793876755, 'avg_acc': 50.40414969450102, 'loss': 1.606286644935608}


EP_train:0:  14%|| 501/3463 [02:19<13:41,  3.61it/s]

{'epoch': 0, 'iter': 500, 'avg_loss': 3.4910830625754867, 'avg_acc': 50.358657684630735, 'loss': 1.608565092086792}


EP_train:0:  15%|| 511/3463 [02:22<13:29,  3.65it/s]

{'epoch': 0, 'iter': 510, 'avg_loss': 3.453098151781788, 'avg_acc': 50.34246575342466, 'loss': 1.558274745941162}


EP_train:0:  15%|| 521/3463 [02:25<15:31,  3.16it/s]

{'epoch': 0, 'iter': 520, 'avg_loss': 3.4168233102663006, 'avg_acc': 50.437859884836854, 'loss': 1.5082417726516724}


EP_train:0:  15%|| 531/3463 [02:27<13:33,  3.60it/s]

{'epoch': 0, 'iter': 530, 'avg_loss': 3.381413204297283, 'avg_acc': 50.43255649717514, 'loss': 1.594383716583252}


EP_train:0:  16%|| 541/3463 [02:30<14:15,  3.42it/s]

{'epoch': 0, 'iter': 540, 'avg_loss': 3.3473126297737448, 'avg_acc': 50.447666358595185, 'loss': 1.594757080078125}


EP_train:0:  16%|| 551/3463 [02:33<13:11,  3.68it/s]

{'epoch': 0, 'iter': 550, 'avg_loss': 3.314580990484968, 'avg_acc': 50.46506352087115, 'loss': 1.4976401329040527}


EP_train:0:  16%|| 561/3463 [02:36<13:05,  3.69it/s]

{'epoch': 0, 'iter': 560, 'avg_loss': 3.2819234683126903, 'avg_acc': 50.453988413547236, 'loss': 1.4764823913574219}


EP_train:0:  16%|| 571/3463 [02:39<13:41,  3.52it/s]

{'epoch': 0, 'iter': 570, 'avg_loss': 3.250801617543877, 'avg_acc': 50.470665499124344, 'loss': 1.5681931972503662}


EP_train:0:  17%|| 581/3463 [02:41<13:18,  3.61it/s]

{'epoch': 0, 'iter': 580, 'avg_loss': 3.2205489580479423, 'avg_acc': 50.47870051635112, 'loss': 1.4308552742004395}


EP_train:0:  17%|| 591/3463 [02:44<14:13,  3.36it/s]

{'epoch': 0, 'iter': 590, 'avg_loss': 3.191144521264659, 'avg_acc': 50.452093908629436, 'loss': 1.5001676082611084}


EP_train:0:  17%|| 601/3463 [02:47<13:28,  3.54it/s]

{'epoch': 0, 'iter': 600, 'avg_loss': 3.1624712515591384, 'avg_acc': 50.49656821963394, 'loss': 1.423638105392456}


EP_train:0:  18%|| 611/3463 [02:50<13:36,  3.50it/s]

{'epoch': 0, 'iter': 610, 'avg_loss': 3.134768659276385, 'avg_acc': 50.508899345335514, 'loss': 1.4674564599990845}


EP_train:0:  18%|| 621/3463 [02:53<13:17,  3.56it/s]

{'epoch': 0, 'iter': 620, 'avg_loss': 3.107339857089155, 'avg_acc': 50.40760869565217, 'loss': 1.4554481506347656}


EP_train:0:  18%|| 631/3463 [02:56<13:54,  3.39it/s]

{'epoch': 0, 'iter': 630, 'avg_loss': 3.0805273828718063, 'avg_acc': 50.41848256735341, 'loss': 1.4262667894363403}


EP_train:0:  19%|| 641/3463 [02:58<12:39,  3.71it/s]

{'epoch': 0, 'iter': 640, 'avg_loss': 3.0543331871166615, 'avg_acc': 50.40220358814352, 'loss': 1.3779809474945068}


EP_train:0:  19%|| 651/3463 [03:01<13:09,  3.56it/s]

{'epoch': 0, 'iter': 650, 'avg_loss': 3.0292575562604562, 'avg_acc': 50.403225806451616, 'loss': 1.4157928228378296}


EP_train:0:  19%|| 661/3463 [03:04<13:45,  3.39it/s]

{'epoch': 0, 'iter': 660, 'avg_loss': 3.005180526248626, 'avg_acc': 50.34512102874432, 'loss': 1.3684340715408325}


EP_train:0:  19%|| 671/3463 [03:07<13:43,  3.39it/s]

{'epoch': 0, 'iter': 670, 'avg_loss': 2.9817725976071308, 'avg_acc': 50.36093517138599, 'loss': 1.4405956268310547}


EP_train:0:  20%|| 681/3463 [03:10<13:31,  3.43it/s]

{'epoch': 0, 'iter': 680, 'avg_loss': 2.959290549443507, 'avg_acc': 50.37628487518355, 'loss': 1.512922763824463}


EP_train:0:  20%|| 691/3463 [03:13<12:37,  3.66it/s]

{'epoch': 0, 'iter': 690, 'avg_loss': 2.9366892143199825, 'avg_acc': 50.43415340086831, 'loss': 1.366170883178711}


EP_train:0:  20%|| 701/3463 [03:16<12:25,  3.70it/s]

{'epoch': 0, 'iter': 700, 'avg_loss': 2.9149484464343365, 'avg_acc': 50.44579172610556, 'loss': 1.3542821407318115}


EP_train:0:  21%|| 711/3463 [03:18<12:26,  3.69it/s]

{'epoch': 0, 'iter': 710, 'avg_loss': 2.8943546929942907, 'avg_acc': 50.4395218002813, 'loss': 1.4410353899002075}


EP_train:0:  21%|| 721/3463 [03:21<12:44,  3.59it/s]

{'epoch': 0, 'iter': 720, 'avg_loss': 2.87339600594133, 'avg_acc': 50.452929958391124, 'loss': 1.3356707096099854}


EP_train:0:  21%|| 731/3463 [03:24<13:14,  3.44it/s]

{'epoch': 0, 'iter': 730, 'avg_loss': 2.8534834681718357, 'avg_acc': 50.40825923392613, 'loss': 1.5284442901611328}


EP_train:0:  21%|| 741/3463 [03:27<13:03,  3.48it/s]

{'epoch': 0, 'iter': 740, 'avg_loss': 2.8339926832922395, 'avg_acc': 50.45335695006747, 'loss': 1.4517512321472168}


EP_train:0:  22%|| 751/3463 [03:30<12:19,  3.67it/s]

{'epoch': 0, 'iter': 750, 'avg_loss': 2.814842209517559, 'avg_acc': 50.45148135818908, 'loss': 1.2829430103302002}


EP_train:0:  22%|| 761/3463 [03:33<12:13,  3.68it/s]

{'epoch': 0, 'iter': 760, 'avg_loss': 2.7960522978440534, 'avg_acc': 50.41475032851511, 'loss': 1.3780567646026611}


EP_train:0:  22%|| 771/3463 [03:35<12:09,  3.69it/s]

{'epoch': 0, 'iter': 770, 'avg_loss': 2.777261776064466, 'avg_acc': 50.39721141374838, 'loss': 1.4085432291030884}


EP_train:0:  23%|| 781/3463 [03:38<12:25,  3.60it/s]

{'epoch': 0, 'iter': 780, 'avg_loss': 2.7593118503365415, 'avg_acc': 50.3661171574904, 'loss': 1.3366742134094238}


EP_train:0:  23%|| 791/3463 [03:41<12:20,  3.61it/s]

{'epoch': 0, 'iter': 790, 'avg_loss': 2.74227611289163, 'avg_acc': 50.32988305941846, 'loss': 1.3747334480285645}


EP_train:0:  23%|| 801/3463 [03:44<12:05,  3.67it/s]

{'epoch': 0, 'iter': 800, 'avg_loss': 2.7256634820266608, 'avg_acc': 50.34917290886391, 'loss': 1.5586915016174316}


EP_train:0:  23%|| 811/3463 [03:46<12:16,  3.60it/s]

{'epoch': 0, 'iter': 810, 'avg_loss': 2.708668427608457, 'avg_acc': 50.33523427866831, 'loss': 1.362011194229126}


EP_train:0:  24%|| 821/3463 [03:49<12:30,  3.52it/s]

{'epoch': 0, 'iter': 820, 'avg_loss': 2.692187973991236, 'avg_acc': 50.28547503045067, 'loss': 1.3111262321472168}


EP_train:0:  24%|| 831/3463 [03:52<12:57,  3.38it/s]

{'epoch': 0, 'iter': 830, 'avg_loss': 2.675907356070554, 'avg_acc': 50.312123947051745, 'loss': 1.328399896621704}


EP_train:0:  24%|| 841/3463 [03:55<12:08,  3.60it/s]

{'epoch': 0, 'iter': 840, 'avg_loss': 2.6604885272548824, 'avg_acc': 50.32141795481569, 'loss': 1.2746193408966064}


EP_train:0:  25%|| 851/3463 [03:58<12:14,  3.56it/s]

{'epoch': 0, 'iter': 850, 'avg_loss': 2.645149083871539, 'avg_acc': 50.31029670975323, 'loss': 1.3657395839691162}


EP_train:0:  25%|| 861/3463 [04:01<12:56,  3.35it/s]

{'epoch': 0, 'iter': 860, 'avg_loss': 2.6305550598516696, 'avg_acc': 50.31213704994193, 'loss': 1.3961849212646484}


EP_train:0:  25%|| 871/3463 [04:04<12:38,  3.42it/s]

{'epoch': 0, 'iter': 870, 'avg_loss': 2.6162464286101264, 'avg_acc': 50.30675947187141, 'loss': 1.3866848945617676}


EP_train:0:  25%|| 881/3463 [04:07<13:03,  3.30it/s]

{'epoch': 0, 'iter': 880, 'avg_loss': 2.6019778825368025, 'avg_acc': 50.299730419977294, 'loss': 1.285130500793457}


EP_train:0:  26%|| 891/3463 [04:10<12:03,  3.56it/s]

{'epoch': 0, 'iter': 890, 'avg_loss': 2.5879397659858334, 'avg_acc': 50.292859147025816, 'loss': 1.28025221824646}


EP_train:0:  26%|| 901/3463 [04:12<11:53,  3.59it/s]

{'epoch': 0, 'iter': 900, 'avg_loss': 2.5739492794517407, 'avg_acc': 50.27920366259712, 'loss': 1.2865521907806396}


EP_train:0:  26%|| 911/3463 [04:15<11:41,  3.64it/s]

{'epoch': 0, 'iter': 910, 'avg_loss': 2.560358509143281, 'avg_acc': 50.27785400658617, 'loss': 1.2998626232147217}


EP_train:0:  27%|| 921/3463 [04:18<11:40,  3.63it/s]

{'epoch': 0, 'iter': 920, 'avg_loss': 2.5468354341649855, 'avg_acc': 50.283319761129206, 'loss': 1.286390781402588}


EP_train:0:  27%|| 931/3463 [04:21<11:25,  3.69it/s]

{'epoch': 0, 'iter': 930, 'avg_loss': 2.533210014938412, 'avg_acc': 50.243353920515574, 'loss': 1.257272720336914}


EP_train:0:  27%|| 941/3463 [04:23<11:53,  3.54it/s]

{'epoch': 0, 'iter': 940, 'avg_loss': 2.5200879219107875, 'avg_acc': 50.244088735387884, 'loss': 1.2907686233520508}


EP_train:0:  27%|| 951/3463 [04:26<11:25,  3.67it/s]

{'epoch': 0, 'iter': 950, 'avg_loss': 2.5066945141672963, 'avg_acc': 50.262881177707676, 'loss': 1.245567798614502}


EP_train:0:  28%|| 961/3463 [04:29<11:31,  3.62it/s]

{'epoch': 0, 'iter': 960, 'avg_loss': 2.4938028357403583, 'avg_acc': 50.26339750260146, 'loss': 1.2750887870788574}


EP_train:0:  28%|| 971/3463 [04:32<12:09,  3.42it/s]

{'epoch': 0, 'iter': 970, 'avg_loss': 2.4812558176097617, 'avg_acc': 50.236547373841404, 'loss': 1.2847785949707031}


EP_train:0:  28%|| 981/3463 [04:35<11:50,  3.49it/s]

{'epoch': 0, 'iter': 980, 'avg_loss': 2.469045634420397, 'avg_acc': 50.24687818552498, 'loss': 1.2826626300811768}


EP_train:0:  29%|| 991/3463 [04:38<12:22,  3.33it/s]

{'epoch': 0, 'iter': 990, 'avg_loss': 2.456616582591406, 'avg_acc': 50.2412336024218, 'loss': 1.3360941410064697}


EP_train:0:  29%|| 1001/3463 [04:40<11:28,  3.57it/s]

{'epoch': 0, 'iter': 1000, 'avg_loss': 2.4447295157225817, 'avg_acc': 50.22633616383616, 'loss': 1.311579942703247}


EP_train:0:  29%|| 1011/3463 [04:43<11:35,  3.53it/s]

{'epoch': 0, 'iter': 1010, 'avg_loss': 2.4327931996033287, 'avg_acc': 50.211733432245296, 'loss': 1.2316865921020508}


EP_train:0:  29%|| 1021/3463 [04:46<13:10,  3.09it/s]

{'epoch': 0, 'iter': 1020, 'avg_loss': 2.421010383779692, 'avg_acc': 50.17446131243879, 'loss': 1.2435944080352783}


EP_train:0:  30%|| 1031/3463 [04:49<11:17,  3.59it/s]

{'epoch': 0, 'iter': 1030, 'avg_loss': 2.409359481239874, 'avg_acc': 50.16973811833172, 'loss': 1.2200098037719727}


EP_train:0:  30%|| 1041/3463 [04:52<11:16,  3.58it/s]

{'epoch': 0, 'iter': 1040, 'avg_loss': 2.3982503834421194, 'avg_acc': 50.14109029779059, 'loss': 1.247941255569458}


EP_train:0:  30%|| 1051/3463 [04:55<11:06,  3.62it/s]

{'epoch': 0, 'iter': 1050, 'avg_loss': 2.3870420992431134, 'avg_acc': 50.13826117982874, 'loss': 1.2000298500061035}


EP_train:0:  31%|| 1061/3463 [04:58<10:39,  3.76it/s]

{'epoch': 0, 'iter': 1060, 'avg_loss': 2.3760672289734175, 'avg_acc': 50.138430725730444, 'loss': 1.1713216304779053}


EP_train:0:  31%|| 1071/3463 [05:00<10:56,  3.64it/s]

{'epoch': 0, 'iter': 1070, 'avg_loss': 2.365300089482611, 'avg_acc': 50.14443277310925, 'loss': 1.262951135635376}


EP_train:0:  31%|| 1081/3463 [05:03<11:02,  3.60it/s]

{'epoch': 0, 'iter': 1080, 'avg_loss': 2.354667057051469, 'avg_acc': 50.141651248843665, 'loss': 1.272630214691162}


EP_train:0:  32%|| 1091/3463 [05:06<11:10,  3.54it/s]

{'epoch': 0, 'iter': 1090, 'avg_loss': 2.344434586060801, 'avg_acc': 50.16183547204399, 'loss': 1.2246350049972534}


EP_train:0:  32%|| 1101/3463 [05:09<10:45,  3.66it/s]

{'epoch': 0, 'iter': 1100, 'avg_loss': 2.3341735950066327, 'avg_acc': 50.127724795640326, 'loss': 1.2125701904296875}


EP_train:0:  32%|| 1111/3463 [05:12<11:46,  3.33it/s]

{'epoch': 0, 'iter': 1110, 'avg_loss': 2.324049287932505, 'avg_acc': 50.129387938793876, 'loss': 1.2234764099121094}


EP_train:0:  32%|| 1121/3463 [05:15<10:58,  3.56it/s]

{'epoch': 0, 'iter': 1120, 'avg_loss': 2.3143255259712077, 'avg_acc': 50.12126449598573, 'loss': 1.1562248468399048}


EP_train:0:  33%|| 1131/3463 [05:17<10:52,  3.57it/s]

{'epoch': 0, 'iter': 1130, 'avg_loss': 2.304487832872136, 'avg_acc': 50.12295534924846, 'loss': 1.111325979232788}


EP_train:0:  33%|| 1141/3463 [05:20<10:35,  3.66it/s]

{'epoch': 0, 'iter': 1140, 'avg_loss': 2.2948139567421153, 'avg_acc': 50.09859772129711, 'loss': 1.1846946477890015}


EP_train:0:  33%|| 1151/3463 [05:23<11:10,  3.45it/s]

{'epoch': 0, 'iter': 1150, 'avg_loss': 2.2854293671822568, 'avg_acc': 50.130321459600346, 'loss': 1.2541110515594482}


EP_train:0:  34%|| 1161/3463 [05:26<10:51,  3.54it/s]

{'epoch': 0, 'iter': 1160, 'avg_loss': 2.2761021644056716, 'avg_acc': 50.11170327304049, 'loss': 1.2434971332550049}


EP_train:0:  34%|| 1171/3463 [05:29<10:36,  3.60it/s]

{'epoch': 0, 'iter': 1170, 'avg_loss': 2.2668288026886247, 'avg_acc': 50.086731426131514, 'loss': 1.1802557706832886}


EP_train:0:  34%|| 1181/3463 [05:32<10:18,  3.69it/s]

{'epoch': 0, 'iter': 1180, 'avg_loss': 2.2577186783524916, 'avg_acc': 50.113780694326834, 'loss': 1.1805131435394287}


EP_train:0:  34%|| 1191/3463 [05:34<10:30,  3.60it/s]

{'epoch': 0, 'iter': 1190, 'avg_loss': 2.248835112446001, 'avg_acc': 50.10888958858103, 'loss': 1.1993134021759033}


EP_train:0:  35%|| 1201/3463 [05:37<10:38,  3.54it/s]

{'epoch': 0, 'iter': 1200, 'avg_loss': 2.239850432846965, 'avg_acc': 50.096273938384684, 'loss': 1.2066700458526611}


EP_train:0:  35%|| 1211/3463 [05:40<10:27,  3.59it/s]

{'epoch': 0, 'iter': 1210, 'avg_loss': 2.231117353273954, 'avg_acc': 50.09289843104872, 'loss': 1.1009774208068848}


EP_train:0:  35%|| 1221/3463 [05:43<10:37,  3.52it/s]

{'epoch': 0, 'iter': 1220, 'avg_loss': 2.2226515340570736, 'avg_acc': 50.09981572481572, 'loss': 1.1882449388504028}


EP_train:0:  36%|| 1231/3463 [05:46<10:06,  3.68it/s]

{'epoch': 0, 'iter': 1230, 'avg_loss': 2.214237671378761, 'avg_acc': 50.10281275385865, 'loss': 1.1781153678894043}


EP_train:0:  36%|| 1241/3463 [05:48<09:54,  3.74it/s]

{'epoch': 0, 'iter': 1240, 'avg_loss': 2.206079930187136, 'avg_acc': 50.09317082997583, 'loss': 1.245970606803894}


EP_train:0:  36%|| 1251/3463 [05:51<10:10,  3.62it/s]

{'epoch': 0, 'iter': 1250, 'avg_loss': 2.1979075444402167, 'avg_acc': 50.093675059952034, 'loss': 1.112560510635376}


EP_train:0:  36%|| 1261/3463 [05:54<10:05,  3.64it/s]

{'epoch': 0, 'iter': 1260, 'avg_loss': 2.190162719477745, 'avg_acc': 50.08178033306899, 'loss': 1.1941791772842407}


EP_train:0:  37%|| 1271/3463 [05:57<09:43,  3.75it/s]

{'epoch': 0, 'iter': 1270, 'avg_loss': 2.1823636913374593, 'avg_acc': 50.08113690007868, 'loss': 1.2143113613128662}


EP_train:0:  37%|| 1281/3463 [05:59<10:10,  3.57it/s]

{'epoch': 0, 'iter': 1280, 'avg_loss': 2.174781060423542, 'avg_acc': 50.069525761124126, 'loss': 1.1712523698806763}


EP_train:0:  37%|| 1291/3463 [06:02<09:55,  3.65it/s]

{'epoch': 0, 'iter': 1290, 'avg_loss': 2.167163279608735, 'avg_acc': 50.10045507358637, 'loss': 1.208480715751648}


EP_train:0:  38%|| 1301/3463 [06:05<09:52,  3.65it/s]

{'epoch': 0, 'iter': 1300, 'avg_loss': 2.159665753894545, 'avg_acc': 50.080466948501154, 'loss': 1.1881229877471924}


EP_train:0:  38%|| 1311/3463 [06:08<09:57,  3.60it/s]

{'epoch': 0, 'iter': 1310, 'avg_loss': 2.1523328715840155, 'avg_acc': 50.07746948893974, 'loss': 1.110529899597168}


EP_train:0:  38%|| 1321/3463 [06:10<09:58,  3.58it/s]

{'epoch': 0, 'iter': 1320, 'avg_loss': 2.145122343351406, 'avg_acc': 50.049678274034825, 'loss': 1.1320679187774658}


EP_train:0:  38%|| 1331/3463 [06:13<09:50,  3.61it/s]

{'epoch': 0, 'iter': 1330, 'avg_loss': 2.137879894336841, 'avg_acc': 50.04930503380917, 'loss': 1.1510555744171143}


EP_train:0:  39%|| 1341/3463 [06:16<10:16,  3.44it/s]

{'epoch': 0, 'iter': 1340, 'avg_loss': 2.1308690274855167, 'avg_acc': 50.044276659209544, 'loss': 1.2081165313720703}


EP_train:0:  39%|| 1351/3463 [06:19<10:14,  3.44it/s]

{'epoch': 0, 'iter': 1350, 'avg_loss': 2.1237339920507017, 'avg_acc': 50.03122686898593, 'loss': 1.224137783050537}


EP_train:0:  39%|| 1361/3463 [06:22<09:43,  3.60it/s]

{'epoch': 0, 'iter': 1360, 'avg_loss': 2.116785938471753, 'avg_acc': 50.01262858192506, 'loss': 1.1473636627197266}


EP_train:0:  40%|| 1371/3463 [06:25<09:33,  3.65it/s]

{'epoch': 0, 'iter': 1370, 'avg_loss': 2.1102195661664442, 'avg_acc': 50.00341903719912, 'loss': 1.1557636260986328}


EP_train:0:  40%|| 1381/3463 [06:27<09:20,  3.71it/s]

{'epoch': 0, 'iter': 1380, 'avg_loss': 2.103677218846011, 'avg_acc': 49.97284576393917, 'loss': 1.1652355194091797}


EP_train:0:  40%|| 1391/3463 [06:30<09:50,  3.51it/s]

{'epoch': 0, 'iter': 1390, 'avg_loss': 2.0971670415571144, 'avg_acc': 49.962931344356576, 'loss': 1.220190167427063}


EP_train:0:  40%|| 1401/3463 [06:33<09:47,  3.51it/s]

{'epoch': 0, 'iter': 1400, 'avg_loss': 2.0907805212730173, 'avg_acc': 49.9810403283369, 'loss': 1.146816372871399}


EP_train:0:  41%|| 1411/3463 [06:36<09:33,  3.58it/s]

{'epoch': 0, 'iter': 1410, 'avg_loss': 2.084295285753457, 'avg_acc': 50.00775159461375, 'loss': 1.2042193412780762}


EP_train:0:  41%|| 1421/3463 [06:39<09:50,  3.46it/s]

{'epoch': 0, 'iter': 1420, 'avg_loss': 2.0781177347768454, 'avg_acc': 50.01099577762139, 'loss': 1.163205862045288}


EP_train:0:  41%|| 1431/3463 [06:42<09:24,  3.60it/s]

{'epoch': 0, 'iter': 1430, 'avg_loss': 2.071785831601365, 'avg_acc': 50.026205450733755, 'loss': 1.1277780532836914}


EP_train:0:  42%|| 1441/3463 [06:44<09:22,  3.59it/s]

{'epoch': 0, 'iter': 1440, 'avg_loss': 2.0657215955603214, 'avg_acc': 50.01084316446912, 'loss': 1.1900396347045898}


EP_train:0:  42%|| 1451/3463 [06:47<09:12,  3.64it/s]

{'epoch': 0, 'iter': 1450, 'avg_loss': 2.059652869657021, 'avg_acc': 50.01507580978636, 'loss': 1.1511467695236206}


EP_train:0:  42%|| 1461/3463 [06:50<09:03,  3.68it/s]

{'epoch': 0, 'iter': 1460, 'avg_loss': 2.053619984075845, 'avg_acc': 50.0, 'loss': 1.2009360790252686}


EP_train:0:  42%|| 1471/3463 [06:53<09:01,  3.68it/s]

{'epoch': 0, 'iter': 1470, 'avg_loss': 2.047599948481748, 'avg_acc': 49.97344493541808, 'loss': 1.2592476606369019}


EP_train:0:  43%|| 1481/3463 [06:55<09:01,  3.66it/s]

{'epoch': 0, 'iter': 1480, 'avg_loss': 2.041714364334508, 'avg_acc': 49.977844361917626, 'loss': 1.133382797241211}


EP_train:0:  43%|| 1491/3463 [06:58<08:53,  3.70it/s]

{'epoch': 0, 'iter': 1490, 'avg_loss': 2.0358972288793082, 'avg_acc': 49.9790409121395, 'loss': 1.141322135925293}


EP_train:0:  43%|| 1500/3463 [07:01<09:11,  3.56it/s]


KeyboardInterrupt: 

# 5) Reference

[BERT from Scratch](https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891) | [BERT vs Roberta vs XLM](https://towardsdatascience.com/bert-roberta-distilbert-xlnet-which-one-to-use-3d5ab82ba5f8) | [StructBert vs Albert vs LongForm](https://towardsdatascience.com/advancing-over-bert-bigbird-convbert-dynabert-bca78a45629c) | [BART](https://medium.com/analytics-vidhya/revealing-bart-a-denoising-objective-for-pretraining-c6e8f8009564)

# 6) Context Guided BERT

[CG-BERT REPO](https://github.com/frankaging/Quasi-Attention-ABSA/blob/main/code/model/CGBERT.py)

```python
# number of context classes
context_id_map_sentihood = [
    'location - 1 - general',
    'location - 1 - price',
    'location - 1 - safety',
    'location - 1 - transit location',
    'location - 2 - general',
    'location - 2 - price',
    'location - 2 - safety',
    'location - 2 - transit location'
]
```


In [None]:
class BERTLayerNorm(nn.Module):
    def __init__(self, config, variance_epsilon=1e-12):
        super(BERTLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(config.hidden_size))
        self.beta = nn.Parameter(torch.zeros(config.hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta

In [None]:
class ContextBERTSelfAttention(nn.Module):
    def __init__(self, config):
        super(ContextBERTSelfAttention, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # learnable context integration factors
        # enforce initialization to zero as to leave the pretrain model
        # unperturbed in the beginning
        self.context_for_q = nn.Linear(self.attention_head_size, self.attention_head_size)
        self.context_for_k = nn.Linear(self.attention_head_size, self.attention_head_size)

        self.lambda_q_context_layer = nn.Linear(self.attention_head_size, 1, bias=False)
        self.lambda_q_query_layer = nn.Linear(self.attention_head_size, 1, bias=False)
        self.lambda_k_context_layer = nn.Linear(self.attention_head_size, 1, bias=False)
        self.lambda_k_key_layer = nn.Linear(self.attention_head_size, 1, bias=False)

        # zero-centered activation function, specifically for re-arch fine tunning
        self.lambda_sig = nn.Sigmoid()
        self.quasi_act = nn.Sigmoid()

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask,
                # optional parameters for saving context information
                device=None, context_embedded=None):

        # (m, seq_len, hidden_dim)
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # (m, num_head, seq_len, head_dim)
        mixed_query_layer = self.transpose_for_scores(mixed_query_layer)
        mixed_key_layer = self.transpose_for_scores(mixed_key_layer)

        # (m, 1, hidden_dim) --> (m, num_head, 1, head_dim)
        context_embedded = self.transpose_for_scores(context_embedded)
        context_embedded_q = self.context_for_q(context_embedded)

        # (m, num_head, 1, head_dim) --> (m, num_head, 1, 1)
        lambda_q_context = self.lambda_q_context_layer(context_embedded_q)
        # (m, num_head, seq_len, head_dim) --> (m, num_head, seq_len, 1)
        lambda_q_query = self.lambda_q_query_layer(mixed_query_layer)
        # (m, num_head, seq_len, 1)
        lambda_q = lambda_q_context + lambda_q_query
        lambda_q = self.lambda_sig(lambda_q)

        # Q_context = (1-lambda_Q) * Q + lambda_Q * Context_Q
        # K_context = (1-lambda_K) * K + lambda_K * Context_K
        # the context is shared and is the same for every head.

        # (m, num_head, seq_len, head_dim)
        contextualized_query_layer = (1 - lambda_q) * mixed_query_layer + lambda_q * context_embedded_q

        # repeat same for key
        context_embedded_k = self.context_for_k(context_embedded)
        lambda_k_context = self.lambda_k_context_layer(context_embedded_k)
        lambda_k_key = self.lambda_k_key_layer(mixed_key_layer)
        lambda_k = lambda_k_context + lambda_k_key
        lambda_k = self.lambda_sig(lambda_k)

        # (m, num_head, seq_len, head_dim)
        contextualized_key_layer = (1 - lambda_k) * mixed_key_layer + lambda_k * context_embedded_k

        ######################################################################

        # (m, num_head, seq_len, seq_len)
        attention_scores = torch.matmul(
            contextualized_query_layer, contextualized_key_layer.transpose(-1, -2)
        )
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        # (m, num_head, seq_len, seq_len)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # (m, num_head, seq_len, head_dim)
        context_layer = torch.matmul(attention_probs, value_layer)

        # (m, seq_len, num_head, head_dim)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # (m, seq_len, hidden_dim)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer

In [None]:
class ContextBERTEncoder(nn.Module):
    def __init__(self, config):
        super(ContextBERTEncoder, self).__init__()

        deep_context_transform_layer = nn.Linear(
            2*config.hidden_size, config.hidden_size
        )

        self.context_layer = nn.ModuleList([
            copy.deepcopy(deep_context_transform_layer) for _ in range(config.num_hidden_layers)
        ])

        layer = ContextBERTLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, device=None, context_embeddings=None):

        all_encoder_layers = []
        layer_index = 0
        for layer_module in self.layer:
            deep_context_hidden = torch.cat([context_embeddings, hidden_states], dim=-1)
            deep_context_hidden = self.context_layer[layer_index](deep_context_hidden)
            deep_context_hidden += context_embeddings

            # BERT encoding
            hidden_states = layer_module(
                hidden_states, attention_mask, device, deep_context_hidden
            )
            # (n_layer, m, seq_len, hidden_dim)
            all_encoder_layers.append(hidden_states)
            layer_index += 1

        return all_encoder_layers

In [None]:
class ContextBertModel(nn.Module):
    def __init__(self, config: BertConfig):
        super(ContextBertModel, self).__init__()
        self.embeddings = BERTEmbeddings(config)
        self.encoder = ContextBERTEncoder(config)
        self.pooler = ContextBERTPooler(config)

        # context embedding
        num_target = 4
        num_aspect = 2
        self.context_embeddings = nn.Embedding(num_target*num_aspect, config.hidden_size)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                device=None, context_ids=None):

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # [batch_size, 1, 1, from_seq_length]
        # broadcast to [batch_size, num_heads, seq_length, seq_length]
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # multiply by negative big number to make it ignore during softmax
        extended_attention_mask = extended_attention_mask.float()
        extended_attention_mask = (1.0 - extended_attention_mask) * -9e9
        embedding_output = self.embeddings(input_ids, token_type_ids)

        # context embeddings
        # [batch_size, 1, context_embedding_dim]
        context_embedded = self.context_embeddings(context_ids).squeeze(dim=1)
        # [batch_size, seq_len, context_embedding_dim]
        seq_len = embedding_output.shape[1]
        context_embedding_output = torch.stack(seq_len*[context_embedded], dim=1)

        # (n_layer, m, seq_len, hidden_dim)
        all_encoder_layers = self.encoder(
            embedding_output,
            extended_attention_mask,
            device,
            context_embedding_output
        )

        sequence_output = all_encoder_layers[-1]
        pooled_output = self.pooler(sequence_output, attention_mask)
        return pooled_output