<a href="https://colab.research.google.com/github/suresh-venkate/NLP_LLM/blob/main/Transformers/BERT/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BERT Implementation - PyTorch



## Install Required Libraries

In [1]:
%%capture
pip install transformers datasets tokenizers

In [2]:
import os
from pathlib import Path
import re
import random
import tqdm
import itertools
import math
import numpy as np

import transformers, datasets
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

## Download dataset

Download cornell movie dialogs corpus ([Link](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html))

In [3]:
%%capture
!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
!unzip -qq cornell_movie_dialogs_corpus.zip
!rm cornell_movie_dialogs_corpus.zip
!mkdir datasets
!mv cornell\ movie-dialogs\ corpus/movie_conversations.txt ./datasets
!mv cornell\ movie-dialogs\ corpus/movie_lines.txt ./datasets

## Load data into memory

In [4]:
### loading all data into memory
corpus_movie_conv = './datasets/movie_conversations.txt'
corpus_movie_lines = './datasets/movie_lines.txt'
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as c:
    conv = c.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as l:
    lines = l.readlines()

In [5]:
### split text
lines_dic = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dic[objects[0]] = objects[-1]

In [6]:
### generate question answer pairs
MAX_LEN = 64
pairs = []
ind = 0
for con in conv:
    ids = eval(con.split(" +++$+++ ")[-1])
    for i in range(len(ids)):
        qa_pairs = []

        if i == len(ids) - 1:
            break

        first = lines_dic[ids[i]].strip()
        second = lines_dic[ids[i+1]].strip()

        qa_pairs.append(' '.join(first.split()[:MAX_LEN]))
        qa_pairs.append(' '.join(second.split()[:MAX_LEN]))
        pairs.append(qa_pairs)

## Wordpiece Tokenization

BERT employs a WordPiece tokenizer. To train the tokenizer, the BertWordPieceTokenizer from the transformer library is used with the steps below:



*   Saving the conversation text into multiple .txt files (with batch of N=10000)
*   Define BertWordPieceTokenizer with some parameters like **clean_text** to remove control characters, **handle_chinese_chars** to include spaces around Chinese characters, **stripe_accents** to remove accents and make é → e, ô → o, and **lowercase** to view capital and lowercase characters as equal.
*   Train the tokenizer based on the file path to .txt files with parameters like **vocab_size** defining the total number of tokens, **min_frequency** for minimum frequency for a pair of tokens to be merged, **special_tokens** defining a list of the special tokens that BERT uses, **limit_alphabet** for a maximum number of different characters, **workpieces_prefix** the prefix added to pieces of words (like ##ing).

In [7]:
### save data as txt file
os.mkdir('./data')
text_data = []
file_count = 0

In [8]:
for sample in tqdm.tqdm([x[0] for x in pairs]):
    text_data.append(sample)

    # once we hit the 10K mark, save to file
    if len(text_data) == 10000:
        with open(f'./data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(text_data))
        text_data = []
        file_count += 1

100%|██████████| 221616/221616 [00:00<00:00, 1579046.85it/s]


In [9]:
paths = [str(x) for x in Path('./data').glob('**/*.txt')]

The code below is used to train a custom tokenizer for a BERT model using the **BertWordPieceTokenizer** class. This tokenizer is designed to tokenize text data into subword tokens suitable for use with BERT-based models.

First, we define the tokenizer with the following parameters:

*   **clean_text**=True: This parameter specifies whether the text data should be cleaned before tokenization. When set to True, it performs text cleaning operations such as removing extra spaces and unwanted characters.
*   **handle_chinese_chars**=False: This parameter specifies whether the tokenizer should handle Chinese characters specially. If set to True, the tokenizer would treat Chinese characters as individual tokens; otherwise, it would treat them as parts of larger tokens.
*   **strip_accents**=False: This parameter specifies whether accents should be stripped from characters. When set to True, it removes accents from characters before tokenization
*   **lowercase**=True: This parameter specifies whether the text should be converted to lowercase before tokenization. When set to True, it converts all text to lowercase.

In [10]:
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=True
)

Then we train the tokenizer with the following parameters:

*   **files**=paths: This parameter specifies the list of file paths containing the text data used for training the tokenizer.
*   **vocab_size**=30_000: This parameter specifies the maximum size of the vocabulary (number of tokens) that the tokenizer will produce. In this case, the vocabulary size is set to 30,000 tokens.
* **min_frequency**=5: This parameter specifies the minimum frequency threshold for tokens to be included in the vocabulary. Tokens that occur less frequently than this threshold will not be included in the vocabulary.
* **limit_alphabet**=1000: This parameter specifies the maximum number of characters in the alphabet used for tokenization. It limits the characters that the tokenizer considers when building the vocabulary.
* **wordpieces_prefix**='##': This parameter specifies the prefix to be used for wordpieces. Wordpieces are subword tokens produced by the tokenizer. The ## prefix indicates that a token is a continuation of a previous token.
* **special_tokens**=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']: This parameter specifies a list of special tokens to be included in the vocabulary. These tokens have specific meanings within the BERT model architecture, such as padding token [PAD], classification token [CLS], separator token [SEP], mask token [MASK], and unknown token [UNK].

In [11]:
tokenizer.train(
    files=paths,
    vocab_size=30_000,
    min_frequency=5,
    limit_alphabet=1000,
    wordpieces_prefix='##',
    special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']
    )

os.mkdir('./bert-it-1')
# Save the tokenizer model to the directory './bert-it-1' with the prefix 'bert-it'
tokenizer.save_model('./bert-it-1', 'bert-it')
# Re-load the tokenizer from the previously saved model files located in the './bert-it-1' directory
# The argument './bert-it-1/bert-it-vocab.txt' specifies the path to the vocabulary file
# ('bert-it-vocab.txt') within the 'bert-it-1' directory.
# The local_files_only=True argument ensures that only locally available files are used for loading
# the tokenizer, meaning it won't attempt to download the model from the internet.
tokenizer = BertTokenizer.from_pretrained('./bert-it-1/bert-it-vocab.txt', local_files_only=True)



**Special Tokens in BERT**

* CLS stands for classification. It serves as the the Start of Sentence (SOS) and represent the meaning of the entire sentence.
* SEP serves as End of Sentence (EOS) and also the separation token between first and second sentences.
* PAD to be added into sentences so that all of them would be in equal length. During the training process, please note that the [PAD] token with id of 0 will not contribute to the gradient .
* MASK for word replacement during masked language prediction
* UNK serves as a replacement for token if it’s not being found in the tokenizer’s vocab.

## Class: BERTDataset

In [12]:
class BERTDataset(Dataset):
    def __init__(self, data_pair, tokenizer, seq_len=64):

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.corpus_lines = len(data_pair)
        self.lines = data_pair

    def __len__(self):
        return self.corpus_lines

    def get_sent(self, index):
        '''return a random sentence pair and corresponding 'is_next' label'''
        t1, t2 = self.get_corpus_line(index)

        # negative or positive pair, for next sentence prediction
        if random.random() > 0.5:
            return t1, t2, 1 # Positive pair - Label is 1
        else:
            return t1, self.get_random_line(), 0 # Negative pair - Label is 0

    def get_corpus_line(self, item):
        '''return sentence pair'''
        return self.lines[item][0], self.lines[item][1]

    def get_random_line(self):
        '''return random single sentence'''
        return self.lines[random.randrange(len(self.lines))][1]

    def random_word(self, sentence):
        """
        Performs the random replacement of tokens in each sentence using the given tokenizer object.
        """
        tokens = sentence.split()
        output_label = []
        output = []

        # 15% of the tokens would be replaced
        for i, token in enumerate(tokens):
            prob = random.random()

            # remove cls and sep token
            token_id = self.tokenizer(token)['input_ids'][1:-1]

            if prob < 0.15:
                prob /= 0.15

                # 80% chance change token to mask token
                if prob < 0.8:
                    for i in range(len(token_id)):
                        output.append(self.tokenizer.vocab['[MASK]'])

                # 10% chance change token to random token
                elif prob < 0.9:
                    for i in range(len(token_id)):
                        output.append(random.randrange(len(self.tokenizer.vocab)))

                # 10% chance change token to current token
                else:
                    output.append(token_id)

                output_label.append(token_id)

            else:
                output.append(token_id)
                for i in range(len(token_id)):
                    output_label.append(0)

        # flattening
        # The following commands take a nested list as input, flattens it by removing one level of nesting,
        # and returns the flattened list as the output.
        output = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output]))
        output_label = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output_label]))
        assert len(output) == len(output_label) # Check whether output and output_label have the same length
        return output, output_label

    def __getitem__(self, item):

        # Step 1: get random sentence pair, either negative or positive (saved as is_next_label)
        t1, t2, is_next_label = self.get_sent(item)

        # Step 2: replace random words in sentence with mask / random words
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # Step 3: Adding CLS and SEP tokens to the start and end of sentences
        # Adding PAD token for labels
        t1 = [self.tokenizer.vocab['[CLS]']] + t1_random + [self.tokenizer.vocab['[SEP]']]
        t2 = t2_random + [self.tokenizer.vocab['[SEP]']]
        t1_label = [self.tokenizer.vocab['[PAD]']] + t1_label + [self.tokenizer.vocab['[PAD]']]
        t2_label = t2_label + [self.tokenizer.vocab['[PAD]']]

        # Step 4: combine sentence 1 and 2 as one input
        # adding PAD tokens to make the sentence same length as seq_len
        # Explanation of below code:
        # segment_label: This line creates a list called segment_label that contains segment labels for tokens in the input sequences.
        # It combines segment labels for two sequences (t1 and t2) by creating a list of 1s (representing the first segment)
        # with a length equal to the number of tokens in t1, and a list of 2s (representing the second segment) with a length
        # equal to the number of tokens in t2. It then concatenates these two lists and truncates the result to the
        # desired sequence length (self.seq_len).
        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        # bert_input: This line creates a list called bert_input by concatenating the tokens from t1 and t2,
        # and then truncates the result to the desired sequence length (self.seq_len). This list likely represents the input
        # tokens for the BERT model.
        bert_input = (t1 + t2)[:self.seq_len]
        # bert_label: This line creates a list called bert_label by concatenating the labels associated with tokens in t1 and t2,
        # and then truncates the result to the desired sequence length (self.seq_len). This list likely represents the labels
        # corresponding to the input tokens for the BERT model.
        bert_label = (t1_label + t2_label)[:self.seq_len]
        # padding: This line creates a list called padding containing [PAD] tokens (which typically represent padding tokens in
        # tokenized sequences) to fill the remaining space up to self.seq_len.
        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        # Padding the sequences: Finally, the code extends bert_input, bert_label, and segment_label lists with the padding
        # tokens to ensure that all sequences have the same length (self.seq_len). This padding step is crucial for batch processing,
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

In [13]:
train_data = BERTDataset(
   pairs, seq_len=MAX_LEN, tokenizer=tokenizer)
train_loader = DataLoader(
   train_data, batch_size=32, shuffle=True, pin_memory=True)
sample_data = next(iter(train_loader))

## Embeddings

In [14]:
class PositionalEmbedding(torch.nn.Module):

    def __init__(self, d_model, max_len=128):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        for pos in range(max_len):
            # for each dimension of the each position
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))

        # include the batch size
        self.pe = pe.unsqueeze(0)
        # self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe

class BERTEmbedding(torch.nn.Module):
    """
    BERT Embedding which is consisted with under features
        1. TokenEmbedding : normal embedding matrix
        2. PositionalEmbedding : adding positional information using sin, cos
        2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2)
        sum of all these features are output of BERTEmbedding
    """

    def __init__(self, vocab_size, embed_size, seq_len=64, dropout=0.1):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param dropout: dropout rate
        """

        super().__init__()
        self.embed_size = embed_size
        # (m, seq_len) --> (m, seq_len, embed_size)
        # padding_idx is not updated during training, remains as fixed pad (0)
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0)
        self.position = PositionalEmbedding(d_model=embed_size, max_len=seq_len)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)

## BERT Model

In [15]:
### attention layers
class MultiHeadedAttention(torch.nn.Module):

    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()

        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.output_linear = torch.nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # fill 0 mask with super small number so it wont affect the softmax weight
        # (batch_size, h, max_len, max_len)
        scores = scores.masked_fill(mask == 0, -1e9)

        # (batch_size, h, max_len, max_len)
        # softmax to put attention weight for all non-pad tokens
        # max_len X max_len matrix of attention
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)

class FeedForward(torch.nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        d_model=768,
        heads=12,
        feed_forward_hidden=768 * 4,
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

In [16]:
class BERT(torch.nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # paper noted they used 4 * hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model)

        # multi-layers transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded token
        # (batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)
        return x

class NextSentencePrediction(torch.nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 2)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token which is the [CLS]
        return self.softmax(self.linear(x[:, 0]))

class MaskedLanguageModel(torch.nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

class BERTLM(torch.nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.d_model)
        self.mask_lm = MaskedLanguageModel(self.bert.d_model, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)

## Optimizer

In [17]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

## BERT Trainer

In [18]:
class BERTTrainer:
    def __init__(
        self,
        model,
        train_dataloader,
        test_dataloader=None,
        lr= 1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10,
        device='cuda'
        ):

        self.device = device
        self.model = model
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
            )

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, data in data_iter:

            # 0. batch_data will be sent into the device(GPU or cpu)
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])

            # 2-1. NLL(negative log likelihood) loss of is_next classification result
            next_loss = self.criterion(next_sent_output, data["is_next"])

            # 2-2. NLLLoss of predicting masked token word
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            # 3. backward and optimization only in train
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        print(
            f"EP{epoch}, {mode}: \
            avg_loss={avg_loss / len(data_iter)}, \
            total_acc={total_correct * 100.0 / total_element}"
        )

## Test Run

In [None]:
'''test run'''

train_data = BERTDataset(
   pairs, seq_len=MAX_LEN, tokenizer=tokenizer)

train_loader = DataLoader(
   train_data, batch_size=32, shuffle=True, pin_memory=True)

bert_model = BERT(
  vocab_size=len(tokenizer.vocab),
  d_model=768,
  n_layers=2,
  heads=12,
  dropout=0.1
)

bert_lm = BERTLM(bert_model, len(tokenizer.vocab))
bert_trainer = BERTTrainer(bert_lm, train_loader, device='cpu')
epochs = 20

for epoch in range(epochs):
  bert_trainer.train(epoch)

Total Parameters: 46699434


EP_train:0:   0%|| 1/6926 [00:11<23:02:26, 11.98s/it]

{'epoch': 0, 'iter': 0, 'avg_loss': 11.229270935058594, 'avg_acc': 50.0, 'loss': 11.229270935058594}


EP_train:0:   0%|| 11/6926 [01:21<13:29:18,  7.02s/it]

{'epoch': 0, 'iter': 10, 'avg_loss': 11.104168024930088, 'avg_acc': 47.159090909090914, 'loss': 11.031635284423828}


EP_train:0:   0%|| 21/6926 [02:32<13:46:37,  7.18s/it]

{'epoch': 0, 'iter': 20, 'avg_loss': 11.026497613816034, 'avg_acc': 49.25595238095239, 'loss': 10.898488998413086}


EP_train:0:   0%|| 31/6926 [03:45<14:34:42,  7.61s/it]

{'epoch': 0, 'iter': 30, 'avg_loss': 10.944456531155494, 'avg_acc': 48.58870967741936, 'loss': 10.688058853149414}


EP_train:0:   1%|| 41/6926 [04:54<13:25:58,  7.02s/it]

{'epoch': 0, 'iter': 40, 'avg_loss': 10.83992834788997, 'avg_acc': 48.78048780487805, 'loss': 10.480600357055664}


EP_train:0:   1%|| 51/6926 [06:05<13:34:31,  7.11s/it]

{'epoch': 0, 'iter': 50, 'avg_loss': 10.738482194788316, 'avg_acc': 48.65196078431372, 'loss': 10.25341796875}


EP_train:0:   1%|| 61/6926 [07:13<12:56:38,  6.79s/it]

{'epoch': 0, 'iter': 60, 'avg_loss': 10.64740951725694, 'avg_acc': 49.07786885245902, 'loss': 10.088589668273926}


EP_train:0:   1%|| 71/6926 [08:22<13:04:13,  6.86s/it]

{'epoch': 0, 'iter': 70, 'avg_loss': 10.569779261736803, 'avg_acc': 48.811619718309856, 'loss': 10.045145034790039}


EP_train:0:   1%|| 81/6926 [09:31<13:03:22,  6.87s/it]

{'epoch': 0, 'iter': 80, 'avg_loss': 10.499472076510205, 'avg_acc': 49.382716049382715, 'loss': 9.978952407836914}


EP_train:0:   1%|| 91/6926 [10:39<12:41:52,  6.69s/it]

{'epoch': 0, 'iter': 90, 'avg_loss': 10.432331200484391, 'avg_acc': 49.55357142857143, 'loss': 9.906965255737305}


EP_train:0:   1%|| 101/6926 [11:48<13:10:16,  6.95s/it]

{'epoch': 0, 'iter': 100, 'avg_loss': 10.374276689963766, 'avg_acc': 49.38118811881188, 'loss': 9.714226722717285}


EP_train:0:   2%|| 111/6926 [12:57<12:54:55,  6.82s/it]

{'epoch': 0, 'iter': 110, 'avg_loss': 10.313651643357835, 'avg_acc': 49.38063063063063, 'loss': 9.679250717163086}


EP_train:0:   2%|| 121/6926 [14:08<13:52:25,  7.34s/it]

{'epoch': 0, 'iter': 120, 'avg_loss': 10.250272088799594, 'avg_acc': 49.40599173553719, 'loss': 9.6199312210083}


EP_train:0:   2%|| 131/6926 [15:21<13:32:21,  7.17s/it]

{'epoch': 0, 'iter': 130, 'avg_loss': 10.19408081142047, 'avg_acc': 49.68988549618321, 'loss': 9.59961986541748}


EP_train:0:   2%|| 141/6926 [16:29<12:51:48,  6.83s/it]

{'epoch': 0, 'iter': 140, 'avg_loss': 10.12808477794025, 'avg_acc': 49.734042553191486, 'loss': 9.179484367370605}


EP_train:0:   2%|| 151/6926 [17:37<12:37:39,  6.71s/it]

{'epoch': 0, 'iter': 150, 'avg_loss': 10.065421893896646, 'avg_acc': 49.56539735099338, 'loss': 9.235438346862793}


EP_train:0:   2%|| 161/6926 [18:47<13:19:59,  7.10s/it]

{'epoch': 0, 'iter': 160, 'avg_loss': 10.007748657131787, 'avg_acc': 49.398291925465834, 'loss': 9.066731452941895}


EP_train:0:   2%|| 171/6926 [19:55<12:55:57,  6.89s/it]

{'epoch': 0, 'iter': 170, 'avg_loss': 9.948594673335204, 'avg_acc': 49.470029239766085, 'loss': 9.129803657531738}


EP_train:0:   3%|| 181/6926 [21:04<12:55:42,  6.90s/it]

{'epoch': 0, 'iter': 180, 'avg_loss': 9.88217057180668, 'avg_acc': 49.37845303867403, 'loss': 8.68496036529541}


EP_train:0:   3%|| 191/6926 [22:14<13:19:25,  7.12s/it]

{'epoch': 0, 'iter': 190, 'avg_loss': 9.82303244905322, 'avg_acc': 49.36191099476439, 'loss': 8.592886924743652}


EP_train:0:   3%|| 201/6926 [23:23<12:36:02,  6.75s/it]

{'epoch': 0, 'iter': 200, 'avg_loss': 9.762187023067948, 'avg_acc': 49.56467661691542, 'loss': 8.337730407714844}


EP_train:0:   3%|| 211/6926 [24:32<13:02:59,  7.00s/it]

{'epoch': 0, 'iter': 210, 'avg_loss': 9.703299036523177, 'avg_acc': 49.58530805687204, 'loss': 8.333684921264648}


EP_train:0:   3%|| 221/6926 [25:43<13:04:29,  7.02s/it]

{'epoch': 0, 'iter': 220, 'avg_loss': 9.64111442695376, 'avg_acc': 49.57579185520362, 'loss': 8.177570343017578}


EP_train:0:   3%|| 231/6926 [26:53<12:53:11,  6.93s/it]

{'epoch': 0, 'iter': 230, 'avg_loss': 9.586802918157536, 'avg_acc': 49.445346320346324, 'loss': 8.646244049072266}


EP_train:0:   3%|| 241/6926 [28:02<12:43:37,  6.85s/it]

{'epoch': 0, 'iter': 240, 'avg_loss': 9.528608076799955, 'avg_acc': 49.50726141078838, 'loss': 8.036192893981934}


EP_train:0:   4%|| 251/6926 [29:12<12:50:20,  6.92s/it]

{'epoch': 0, 'iter': 250, 'avg_loss': 9.477128724177994, 'avg_acc': 49.29033864541832, 'loss': 8.161190032958984}


EP_train:0:   4%|| 261/6926 [30:20<12:30:37,  6.76s/it]

{'epoch': 0, 'iter': 260, 'avg_loss': 9.421149851261884, 'avg_acc': 49.377394636015325, 'loss': 8.445257186889648}


EP_train:0:   4%|| 271/6926 [31:29<12:39:33,  6.85s/it]

{'epoch': 0, 'iter': 270, 'avg_loss': 9.367200758184454, 'avg_acc': 49.40036900369004, 'loss': 7.755031585693359}


EP_train:0:   4%|| 281/6926 [32:39<12:46:33,  6.92s/it]

{'epoch': 0, 'iter': 280, 'avg_loss': 9.318998716904176, 'avg_acc': 49.31049822064057, 'loss': 7.907980918884277}


EP_train:0:   4%|| 291/6926 [33:47<12:27:18,  6.76s/it]

{'epoch': 0, 'iter': 290, 'avg_loss': 9.268477189172174, 'avg_acc': 49.24828178694158, 'loss': 7.702703475952148}


EP_train:0:   4%|| 301/6926 [34:57<12:47:28,  6.95s/it]

{'epoch': 0, 'iter': 300, 'avg_loss': 9.215609217798987, 'avg_acc': 49.138289036544855, 'loss': 7.809710502624512}


EP_train:0:   4%|| 311/6926 [36:06<12:18:13,  6.70s/it]

{'epoch': 0, 'iter': 310, 'avg_loss': 9.165657961866863, 'avg_acc': 49.20618971061093, 'loss': 7.7310380935668945}


EP_train:0:   5%|| 321/6926 [37:15<12:45:15,  6.95s/it]

{'epoch': 0, 'iter': 320, 'avg_loss': 9.117040769333409, 'avg_acc': 49.21144859813084, 'loss': 7.373783111572266}


EP_train:0:   5%|| 331/6926 [38:26<13:25:36,  7.33s/it]

{'epoch': 0, 'iter': 330, 'avg_loss': 9.071285486941612, 'avg_acc': 49.121978851963746, 'loss': 8.156224250793457}


EP_train:0:   5%|| 341/6926 [39:37<12:45:46,  6.98s/it]

{'epoch': 0, 'iter': 340, 'avg_loss': 9.028866549740789, 'avg_acc': 49.08357771260997, 'loss': 7.62197208404541}


EP_train:0:   5%|| 351/6926 [40:47<12:35:25,  6.89s/it]

{'epoch': 0, 'iter': 350, 'avg_loss': 8.98063139086775, 'avg_acc': 49.074074074074076, 'loss': 7.387594223022461}


EP_train:0:   5%|| 361/6926 [42:00<13:09:45,  7.22s/it]

{'epoch': 0, 'iter': 360, 'avg_loss': 8.937294472617786, 'avg_acc': 49.134349030470915, 'loss': 7.471529960632324}


EP_train:0:   5%|| 371/6926 [43:10<12:57:59,  7.12s/it]

{'epoch': 0, 'iter': 370, 'avg_loss': 8.90279979525872, 'avg_acc': 49.33456873315364, 'loss': 7.588113784790039}


EP_train:0:   6%|| 381/6926 [44:22<13:31:39,  7.44s/it]

{'epoch': 0, 'iter': 380, 'avg_loss': 8.860997390246453, 'avg_acc': 49.286417322834644, 'loss': 7.452469348907471}


EP_train:0:   6%|| 391/6926 [45:36<13:27:14,  7.41s/it]

{'epoch': 0, 'iter': 390, 'avg_loss': 8.822794258137188, 'avg_acc': 49.20076726342711, 'loss': 7.28240966796875}


EP_train:0:   6%|| 401/6926 [46:48<13:27:06,  7.42s/it]

{'epoch': 0, 'iter': 400, 'avg_loss': 8.784269492227835, 'avg_acc': 49.12718204488778, 'loss': 7.207321643829346}


EP_train:0:   6%|| 411/6926 [48:02<13:13:21,  7.31s/it]

{'epoch': 0, 'iter': 410, 'avg_loss': 8.74794492582335, 'avg_acc': 49.10279805352798, 'loss': 7.248763084411621}


EP_train:0:   6%|| 421/6926 [49:15<13:04:01,  7.23s/it]

{'epoch': 0, 'iter': 420, 'avg_loss': 8.711120992828151, 'avg_acc': 49.049881235154395, 'loss': 7.343123912811279}


EP_train:0:   6%|| 431/6926 [50:23<12:16:04,  6.80s/it]

{'epoch': 0, 'iter': 430, 'avg_loss': 8.670904482600307, 'avg_acc': 49.10092807424594, 'loss': 7.223913669586182}


EP_train:0:   6%|| 441/6926 [51:33<12:37:16,  7.01s/it]

{'epoch': 0, 'iter': 440, 'avg_loss': 8.63667128253686, 'avg_acc': 48.97250566893424, 'loss': 7.172632694244385}


EP_train:0:   7%|| 451/6926 [52:41<12:20:59,  6.87s/it]

{'epoch': 0, 'iter': 450, 'avg_loss': 8.603705856593908, 'avg_acc': 49.07843680709534, 'loss': 6.8051981925964355}


EP_train:0:   7%|| 461/6926 [53:49<12:00:39,  6.69s/it]

{'epoch': 0, 'iter': 460, 'avg_loss': 8.569726867427537, 'avg_acc': 49.06453362255965, 'loss': 7.011475086212158}


EP_train:0:   7%|| 471/6926 [54:59<12:32:10,  6.99s/it]

{'epoch': 0, 'iter': 470, 'avg_loss': 8.536079014942144, 'avg_acc': 49.110934182590235, 'loss': 7.309597492218018}


EP_train:0:   7%|| 481/6926 [56:07<12:12:00,  6.81s/it]

{'epoch': 0, 'iter': 480, 'avg_loss': 8.502417079624168, 'avg_acc': 49.109927234927234, 'loss': 6.878180980682373}


EP_train:0:   7%|| 491/6926 [57:15<11:51:07,  6.63s/it]

{'epoch': 0, 'iter': 490, 'avg_loss': 8.468869794896081, 'avg_acc': 49.210794297352344, 'loss': 6.820246696472168}


EP_train:0:   7%|| 501/6926 [58:23<12:10:49,  6.82s/it]

{'epoch': 0, 'iter': 500, 'avg_loss': 8.434950436422687, 'avg_acc': 49.189121756487026, 'loss': 6.88431453704834}


EP_train:0:   7%|| 511/6926 [59:31<12:00:08,  6.74s/it]

{'epoch': 0, 'iter': 510, 'avg_loss': 8.404115230836513, 'avg_acc': 49.11325831702544, 'loss': 6.931926250457764}


EP_train:0:   8%|| 521/6926 [1:00:38<11:48:03,  6.63s/it]

{'epoch': 0, 'iter': 520, 'avg_loss': 8.374184082199653, 'avg_acc': 49.05830134357006, 'loss': 7.044764518737793}


EP_train:0:   8%|| 531/6926 [1:01:48<12:15:25,  6.90s/it]

{'epoch': 0, 'iter': 530, 'avg_loss': 8.344793249658272, 'avg_acc': 49.03483992467043, 'loss': 6.98124361038208}


EP_train:0:   8%|| 541/6926 [1:02:57<12:21:03,  6.96s/it]

{'epoch': 0, 'iter': 540, 'avg_loss': 8.315543729143972, 'avg_acc': 49.20864140480592, 'loss': 6.712698936462402}


EP_train:0:   8%|| 551/6926 [1:04:08<12:32:13,  7.08s/it]

{'epoch': 0, 'iter': 550, 'avg_loss': 8.286369409405385, 'avg_acc': 49.17196007259528, 'loss': 7.077779769897461}


EP_train:0:   8%|| 561/6926 [1:05:20<12:35:15,  7.12s/it]

{'epoch': 0, 'iter': 560, 'avg_loss': 8.256500496583826, 'avg_acc': 49.103163992869874, 'loss': 6.8565521240234375}


EP_train:0:   8%|| 571/6926 [1:06:33<13:02:06,  7.38s/it]

{'epoch': 0, 'iter': 570, 'avg_loss': 8.227826123479787, 'avg_acc': 49.09150612959719, 'loss': 6.297884464263916}


EP_train:0:   8%|| 581/6926 [1:07:42<12:03:07,  6.84s/it]

{'epoch': 0, 'iter': 580, 'avg_loss': 8.200205147163798, 'avg_acc': 49.058734939759034, 'loss': 6.905010223388672}


EP_train:0:   9%|| 591/6926 [1:08:52<12:19:13,  7.00s/it]

{'epoch': 0, 'iter': 590, 'avg_loss': 8.171085003467178, 'avg_acc': 49.10109983079526, 'loss': 6.466160297393799}


EP_train:0:   9%|| 601/6926 [1:10:00<12:00:30,  6.83s/it]

{'epoch': 0, 'iter': 600, 'avg_loss': 8.142924584088826, 'avg_acc': 49.17325291181364, 'loss': 6.203374862670898}


EP_train:0:   9%|| 611/6926 [1:11:07<11:41:10,  6.66s/it]

{'epoch': 0, 'iter': 610, 'avg_loss': 8.11574821659469, 'avg_acc': 49.22770049099836, 'loss': 6.315889358520508}


EP_train:0:   9%|| 621/6926 [1:12:17<12:09:42,  6.94s/it]

{'epoch': 0, 'iter': 620, 'avg_loss': 8.090824965693525, 'avg_acc': 49.25523349436393, 'loss': 6.390105247497559}


EP_train:0:   9%|| 631/6926 [1:13:25<11:45:07,  6.72s/it]

{'epoch': 0, 'iter': 630, 'avg_loss': 8.06238240334198, 'avg_acc': 49.24227416798732, 'loss': 6.848438739776611}


EP_train:0:   9%|| 641/6926 [1:14:32<11:33:28,  6.62s/it]

{'epoch': 0, 'iter': 640, 'avg_loss': 8.03710368270993, 'avg_acc': 49.27847113884555, 'loss': 6.730952739715576}


EP_train:0:   9%|| 651/6926 [1:15:41<12:05:40,  6.94s/it]

{'epoch': 0, 'iter': 650, 'avg_loss': 8.014080607946019, 'avg_acc': 49.27995391705069, 'loss': 6.294173240661621}


EP_train:0:  10%|| 661/6926 [1:16:49<11:43:52,  6.74s/it]

{'epoch': 0, 'iter': 660, 'avg_loss': 7.990376603044288, 'avg_acc': 49.31448562783661, 'loss': 6.609847545623779}


EP_train:0:  10%|| 671/6926 [1:17:57<11:45:39,  6.77s/it]

{'epoch': 0, 'iter': 670, 'avg_loss': 7.966337990654208, 'avg_acc': 49.32470193740686, 'loss': 6.512466907501221}


EP_train:0:  10%|| 681/6926 [1:19:05<11:58:22,  6.90s/it]

{'epoch': 0, 'iter': 680, 'avg_loss': 7.943580927127608, 'avg_acc': 49.38968428781204, 'loss': 6.15503454208374}


EP_train:0:  10%|| 691/6926 [1:20:13<11:24:47,  6.59s/it]

{'epoch': 0, 'iter': 690, 'avg_loss': 7.923071948905757, 'avg_acc': 49.38042691751085, 'loss': 6.355175018310547}


EP_train:0:  10%|| 701/6926 [1:21:21<11:51:54,  6.86s/it]

{'epoch': 0, 'iter': 700, 'avg_loss': 7.900251528675988, 'avg_acc': 49.331312410841655, 'loss': 6.1471428871154785}


EP_train:0:  10%|| 711/6926 [1:22:29<11:49:25,  6.85s/it]

{'epoch': 0, 'iter': 710, 'avg_loss': 7.8816061214388, 'avg_acc': 49.287974683544306, 'loss': 6.633476257324219}


EP_train:0:  10%|| 721/6926 [1:23:37<11:23:04,  6.61s/it]

{'epoch': 0, 'iter': 720, 'avg_loss': 7.859292478071999, 'avg_acc': 49.29785020804438, 'loss': 5.941165924072266}


EP_train:0:  11%|| 731/6926 [1:24:44<11:41:58,  6.80s/it]

{'epoch': 0, 'iter': 730, 'avg_loss': 7.838034547630968, 'avg_acc': 49.27753077975376, 'loss': 6.343728542327881}


EP_train:0:  11%|| 741/6926 [1:25:52<12:04:00,  7.02s/it]

{'epoch': 0, 'iter': 740, 'avg_loss': 7.816340337076329, 'avg_acc': 49.2914979757085, 'loss': 6.437711715698242}


EP_train:0:  11%|| 751/6926 [1:27:00<11:28:13,  6.69s/it]

{'epoch': 0, 'iter': 750, 'avg_loss': 7.795295140714684, 'avg_acc': 49.32173768308921, 'loss': 6.4720563888549805}


EP_train:0:  11%|| 761/6926 [1:28:07<11:37:59,  6.79s/it]

{'epoch': 0, 'iter': 760, 'avg_loss': 7.77635867210318, 'avg_acc': 49.3511826544021, 'loss': 6.63962984085083}


EP_train:0:  11%|| 771/6926 [1:29:14<11:28:55,  6.72s/it]

{'epoch': 0, 'iter': 770, 'avg_loss': 7.7550934700031995, 'avg_acc': 49.371757457846954, 'loss': 6.039074420928955}


EP_train:0:  11%|| 781/6926 [1:30:22<11:28:46,  6.73s/it]

{'epoch': 0, 'iter': 780, 'avg_loss': 7.735177911961125, 'avg_acc': 49.34379001280409, 'loss': 6.19773006439209}


EP_train:0:  11%|| 791/6926 [1:31:29<11:23:46,  6.69s/it]

{'epoch': 0, 'iter': 790, 'avg_loss': 7.714418352177713, 'avg_acc': 49.352085967130215, 'loss': 5.585363388061523}


EP_train:0:  12%|| 801/6926 [1:32:35<11:19:18,  6.65s/it]

{'epoch': 0, 'iter': 800, 'avg_loss': 7.695127703872662, 'avg_acc': 49.39918851435706, 'loss': 6.060218811035156}


EP_train:0:  12%|| 811/6926 [1:33:44<12:05:47,  7.12s/it]

{'epoch': 0, 'iter': 810, 'avg_loss': 7.678602425002581, 'avg_acc': 49.418156596794084, 'loss': 6.204774379730225}


EP_train:0:  12%|| 821/6926 [1:34:52<11:24:38,  6.73s/it]

{'epoch': 0, 'iter': 820, 'avg_loss': 7.660406262517411, 'avg_acc': 49.387180267965896, 'loss': 6.442275524139404}


EP_train:0:  12%|| 831/6926 [1:35:57<11:01:39,  6.51s/it]

{'epoch': 0, 'iter': 830, 'avg_loss': 7.641855226670505, 'avg_acc': 49.3531889290012, 'loss': 6.224607467651367}


EP_train:0:  12%|| 841/6926 [1:37:04<11:03:46,  6.54s/it]

{'epoch': 0, 'iter': 840, 'avg_loss': 7.624950363576483, 'avg_acc': 49.26055291319857, 'loss': 6.527813911437988}


EP_train:0:  12%|| 851/6926 [1:38:12<11:19:11,  6.71s/it]

{'epoch': 0, 'iter': 850, 'avg_loss': 7.610174274892841, 'avg_acc': 49.30596357226792, 'loss': 6.793323993682861}


EP_train:0:  12%|| 861/6926 [1:39:18<11:09:56,  6.63s/it]

{'epoch': 0, 'iter': 860, 'avg_loss': 7.594536004858427, 'avg_acc': 49.32854239256678, 'loss': 6.044194221496582}


EP_train:0:  13%|| 871/6926 [1:40:25<11:07:25,  6.61s/it]

{'epoch': 0, 'iter': 870, 'avg_loss': 7.579132640676575, 'avg_acc': 49.34701492537313, 'loss': 6.153696060180664}


EP_train:0:  13%|| 881/6926 [1:41:34<11:21:45,  6.77s/it]

{'epoch': 0, 'iter': 880, 'avg_loss': 7.562884536963992, 'avg_acc': 49.36506810442679, 'loss': 5.613094329833984}


EP_train:0:  13%|| 891/6926 [1:42:40<11:08:17,  6.64s/it]

{'epoch': 0, 'iter': 890, 'avg_loss': 7.546372229001338, 'avg_acc': 49.38973063973064, 'loss': 5.862337589263916}


EP_train:0:  13%|| 901/6926 [1:43:47<11:20:02,  6.77s/it]

{'epoch': 0, 'iter': 900, 'avg_loss': 7.52985909088338, 'avg_acc': 49.36182019977802, 'loss': 6.069211483001709}


EP_train:0:  13%|| 911/6926 [1:44:56<11:29:21,  6.88s/it]

{'epoch': 0, 'iter': 910, 'avg_loss': 7.513922600269841, 'avg_acc': 49.37568605927552, 'loss': 6.128734111785889}


EP_train:0:  13%|| 921/6926 [1:46:02<11:13:51,  6.73s/it]

{'epoch': 0, 'iter': 920, 'avg_loss': 7.497400721302509, 'avg_acc': 49.39942996742671, 'loss': 5.859554767608643}


EP_train:0:  13%|| 931/6926 [1:47:09<10:56:39,  6.57s/it]

{'epoch': 0, 'iter': 930, 'avg_loss': 7.482209778241773, 'avg_acc': 49.41595059076262, 'loss': 6.493114471435547}


EP_train:0:  14%|| 941/6926 [1:48:18<11:34:23,  6.96s/it]

{'epoch': 0, 'iter': 940, 'avg_loss': 7.465962304065635, 'avg_acc': 49.42547821466525, 'loss': 6.153231143951416}


EP_train:0:  14%|| 951/6926 [1:49:24<11:10:28,  6.73s/it]

{'epoch': 0, 'iter': 950, 'avg_loss': 7.451390663030646, 'avg_acc': 49.480809674027334, 'loss': 6.15553092956543}


EP_train:0:  14%|| 961/6926 [1:50:30<10:56:57,  6.61s/it]

{'epoch': 0, 'iter': 960, 'avg_loss': 7.436933434096385, 'avg_acc': 49.46344953173777, 'loss': 6.214972496032715}


EP_train:0:  14%|| 971/6926 [1:51:40<11:37:17,  7.03s/it]

{'epoch': 0, 'iter': 970, 'avg_loss': 7.422545943274925, 'avg_acc': 49.472193614830076, 'loss': 6.190056324005127}


EP_train:0:  14%|| 981/6926 [1:52:46<11:01:24,  6.68s/it]

{'epoch': 0, 'iter': 980, 'avg_loss': 7.409824494800315, 'avg_acc': 49.4361620795107, 'loss': 6.491195201873779}


EP_train:0:  14%|| 991/6926 [1:53:53<10:55:19,  6.62s/it]

{'epoch': 0, 'iter': 990, 'avg_loss': 7.395593977118837, 'avg_acc': 49.400857719475276, 'loss': 6.070124626159668}


EP_train:0:  14%|| 1001/6926 [1:55:02<11:27:33,  6.96s/it]

{'epoch': 0, 'iter': 1000, 'avg_loss': 7.382440793764341, 'avg_acc': 49.43181818181818, 'loss': 5.893037796020508}


EP_train:0:  15%|| 1011/6926 [1:56:11<11:17:05,  6.87s/it]

{'epoch': 0, 'iter': 1010, 'avg_loss': 7.36883291552022, 'avg_acc': 49.44362017804154, 'loss': 5.879763603210449}


EP_train:0:  15%|| 1021/6926 [1:57:18<10:55:01,  6.66s/it]

{'epoch': 0, 'iter': 1020, 'avg_loss': 7.355792645259188, 'avg_acc': 49.464373163565135, 'loss': 6.3923540115356445}


EP_train:0:  15%|| 1031/6926 [1:58:27<11:27:33,  7.00s/it]

{'epoch': 0, 'iter': 1030, 'avg_loss': 7.34320897020031, 'avg_acc': 49.46047526673133, 'loss': 6.098097324371338}


EP_train:0:  15%|| 1041/6926 [1:59:35<11:07:44,  6.81s/it]

{'epoch': 0, 'iter': 1040, 'avg_loss': 7.328962465299081, 'avg_acc': 49.47466378482229, 'loss': 5.579588413238525}


EP_train:0:  15%|| 1051/6926 [2:00:43<10:52:28,  6.66s/it]

{'epoch': 0, 'iter': 1050, 'avg_loss': 7.316214513370358, 'avg_acc': 49.476688867745004, 'loss': 6.321515083312988}


EP_train:0:  15%|| 1061/6926 [2:01:52<11:46:59,  7.23s/it]

{'epoch': 0, 'iter': 1060, 'avg_loss': 7.3051965602943065, 'avg_acc': 49.50812912346842, 'loss': 6.5473408699035645}


EP_train:0:  15%|| 1071/6926 [2:03:01<11:09:54,  6.86s/it]

{'epoch': 0, 'iter': 1070, 'avg_loss': 7.293460178108108, 'avg_acc': 49.49813258636788, 'loss': 5.900511264801025}


EP_train:0:  16%|| 1081/6926 [2:04:10<11:05:44,  6.83s/it]

{'epoch': 0, 'iter': 1080, 'avg_loss': 7.281494457340152, 'avg_acc': 49.47964847363552, 'loss': 5.7203521728515625}


EP_train:0:  16%|| 1091/6926 [2:05:18<11:07:41,  6.87s/it]

{'epoch': 0, 'iter': 1090, 'avg_loss': 7.270098185779631, 'avg_acc': 49.44431714023831, 'loss': 6.042630672454834}


EP_train:0:  16%|| 1101/6926 [2:06:28<11:08:09,  6.88s/it]

{'epoch': 0, 'iter': 1100, 'avg_loss': 7.258415792120466, 'avg_acc': 49.4522025431426, 'loss': 5.6404523849487305}


EP_train:0:  16%|| 1111/6926 [2:07:36<10:55:30,  6.76s/it]

{'epoch': 0, 'iter': 1110, 'avg_loss': 7.2461077821458595, 'avg_acc': 49.47682268226823, 'loss': 6.175522327423096}


EP_train:0:  16%|| 1121/6926 [2:08:45<11:09:00,  6.91s/it]

{'epoch': 0, 'iter': 1120, 'avg_loss': 7.234284215479637, 'avg_acc': 49.506578947368425, 'loss': 5.995721340179443}


EP_train:0:  16%|| 1131/6926 [2:09:54<11:07:07,  6.91s/it]

{'epoch': 0, 'iter': 1130, 'avg_loss': 7.223184182093694, 'avg_acc': 49.508178603006186, 'loss': 5.838449478149414}


EP_train:0:  16%|| 1141/6926 [2:11:03<11:02:41,  6.87s/it]

{'epoch': 0, 'iter': 1140, 'avg_loss': 7.211702405309385, 'avg_acc': 49.52892199824715, 'loss': 6.0999040603637695}


EP_train:0:  17%|| 1151/6926 [2:12:11<10:50:05,  6.75s/it]

{'epoch': 0, 'iter': 1150, 'avg_loss': 7.19995288782592, 'avg_acc': 49.511294526498695, 'loss': 5.774099826812744}


EP_train:0:  17%|| 1161/6926 [2:13:22<11:46:44,  7.36s/it]

{'epoch': 0, 'iter': 1160, 'avg_loss': 7.189244213235676, 'avg_acc': 49.49935400516796, 'loss': 5.82039213180542}


EP_train:0:  17%|| 1171/6926 [2:14:31<10:56:53,  6.85s/it]

{'epoch': 0, 'iter': 1170, 'avg_loss': 7.179343334331561, 'avg_acc': 49.50629803586678, 'loss': 5.673587799072266}


EP_train:0:  17%|| 1181/6926 [2:15:41<11:14:00,  7.04s/it]

{'epoch': 0, 'iter': 1180, 'avg_loss': 7.1690607656370675, 'avg_acc': 49.49989415749365, 'loss': 5.9068098068237305}


EP_train:0:  17%|| 1191/6926 [2:16:51<11:04:02,  6.95s/it]

{'epoch': 0, 'iter': 1190, 'avg_loss': 7.15878241928959, 'avg_acc': 49.49097397145256, 'loss': 5.861592769622803}


EP_train:0:  17%|| 1201/6926 [2:18:01<11:13:18,  7.06s/it]

{'epoch': 0, 'iter': 1200, 'avg_loss': 7.148079333754007, 'avg_acc': 49.46659034138218, 'loss': 5.683917045593262}


EP_train:0:  17%|| 1211/6926 [2:19:10<11:01:42,  6.95s/it]

{'epoch': 0, 'iter': 1210, 'avg_loss': 7.138643024973787, 'avg_acc': 49.476156069364166, 'loss': 5.805774688720703}


EP_train:0:  18%|| 1217/6926 [2:19:52<11:18:02,  7.13s/it]