In [1]:
from typing import Callable, Dict, Optional, Tuple, Union, Any, collections
import time
import torch
import math
import copy 

from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss, MSELoss

from transformers import Trainer, TrainerState, TrainingArguments


from transformers.file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_sagemaker_distributed_available,
    is_torch_tpu_available,
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
)

from transformers.utils import logging
logger = logging.get_logger(__name__)

from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel, GPT2Model

In [2]:
# --- Batching / rerank configuration ---
batch_size = 32        # examples per DataLoader batch
MAX_LEN = 128          # token length every example is truncated/padded to
CAN_NUM = 5            # top-k candidates considered at each rerank position
num_of_rerank = 10     # rerank positions sampled per forward pass
sample_every = 1000    # run the "inside" validation every N training steps

# some parameters I cooked up that work reasonably well
epochs = 5
learning_rate = 5e-4
warmup_steps = 100     # scheduler warm-up length, as an integer step count
epsilon = 1e-8

# Scratch dict that other cells write intermediate tensors into for interactive inspection.
# (A `global` statement is a no-op at module level, so it is not needed here.)
debug = {}

In [3]:
# Load the stock GPT-2 configuration; the only override is output_hidden_states=False.
configuration = GPT2Config.from_pretrained(
    'gpt2',
    output_hidden_states=False,
)
VOCAB_SIZE = configuration.vocab_size
VOCAB_SIZE

50257

In [4]:
class myGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 LM with an additional scalar head that re-ranks top-k next-token candidates.

    The regular language-modelling path (``transformer`` + ``lm_head``) is kept. On top of it,
    ``rerank_linear_head`` maps a hidden state to a single score; ``forward`` uses it to score
    ``CAN_NUM`` candidate tokens at ``num_of_rerank`` randomly chosen positions.

    Relies on the notebook-level names ``MAX_LEN``, ``CAN_NUM``, ``num_of_rerank``, ``debug``,
    ``random`` and ``np``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # One scalar score per candidate hidden state.
        self.rerank_linear_head = nn.Linear(config.n_embd, 1, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_training=False,
    ):
        r"""
        Run the LM segment by segment and compute both the LM loss and the rerank loss.

        The sequence is cut at ``num_of_rerank`` sampled positions. Each segment is fed to the
        transformer with the key/value cache of the previous segments. At the end of every
        segment but the last, the top ``CAN_NUM`` tokens under ``lm_head`` become candidates for
        the next token. The sequence ``[eos, candidates, eos, candidates]`` is then fed on top of
        the cache, and the hidden states of the second copy of the candidates are scored by
        ``rerank_linear_head``. The cache returned by that rerank pass is discarded.

        input_ids / attention_mask / labels:
            Required. They are sliced along dimension 1 using positions up to ``MAX_LEN``.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
        is_training:
            ``True``: when the gold next token is not among the candidates, a randomly chosen
            candidate is overwritten with it, and the module is switched back to train mode
            before returning.
            ``False``: rows whose gold token is not among the candidates are dropped from the
            rerank loss, and ``normal_loss_in_rerank_place`` is computed as well.
        past_key_values, token_type_ids, position_ids, head_mask, inputs_embeds,
        encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions,
        output_hidden_states, return_dict:
            Accepted for signature compatibility; this implementation does not use them.

        Returns a dict with the keys ``loss`` (``normal_loss + rerank_loss``), ``normal_loss``,
        ``normal_loss_in_rerank_place`` (``None`` when ``is_training``) and ``rerank_loss``.
        """
        # Disable dropout & co. for the whole pass so the rerank passes reuse a consistent cache.
        # Operate on `self`, not on whatever notebook variable happens to be called `model`.
        self.eval()

        # Create helper tensors on the same device as the inputs rather than on a global device.
        device = input_ids.device

        debug['input_ids'] = input_ids
        debug['attention_mask'] = attention_mask
        debug['labels'] = labels

        rerank_places = random.sample(np.arange(1, MAX_LEN-2-CAN_NUM*2).tolist(), k=num_of_rerank) #no duplicate
        # add first and last positions so consecutive entries delimit the segments
        rerank_places = np.concatenate(([0], np.sort(rerank_places), [MAX_LEN]))
        debug['rerank_places'] = rerank_places

        past_key_values = None
        hidden_states = []
        all_rerank_hidden_states = []
        all_rerank_labels = []
        no_rerank_logits = []  # only filled when evaluating

        for i in range(num_of_rerank+1):
            # ---- normal stage: run the next segment on top of the cache ----
            segment_input_ids = input_ids[:, rerank_places[i]:rerank_places[i+1]]
            # the mask must cover the cached positions as well as the new segment
            segment_attention_masks = attention_mask[:, :rerank_places[i+1]]

            debug['segment_input_ids'] = segment_input_ids
            debug['segment_attention_masks'] = segment_attention_masks

            segment_outputs = self.transformer(
                segment_input_ids,
                attention_mask = segment_attention_masks,
                past_key_values = past_key_values
            )

            debug['segment_outputs'] = segment_outputs

            segment_hidden = segment_outputs[0]
            past_key_values = segment_outputs[1]

            hidden_states.append(segment_hidden)

            # the final segment has no rerank position after it
            if i == num_of_rerank:
                break

            # ---- rerank stage ----
            # candidates = top-k tokens under the LM head at the last position of the segment
            logits_before_rerank = self.lm_head(segment_hidden[:, -1, :])
            candidate_token_logits, candidate_token_ids = torch.topk(logits_before_rerank, CAN_NUM)
            # gold token that follows the segment
            rerank_labels = labels[..., rerank_places[i+1]]

            debug['candidate_token_ids'] = candidate_token_ids
            debug['rerank_labels'] = rerank_labels

            check_labels = rerank_labels.tolist()
            check_candidates = candidate_token_ids.tolist()

            assert len(check_labels)==len(check_candidates)

            if is_training:
                # training: guarantee the gold token is a candidate, injecting it if necessary
                rerank_labels_this_place = []
                for j in range(len(check_labels)):
                    if check_labels[j] not in check_candidates[j]:
                        replace_index = np.random.randint(CAN_NUM)
                        candidate_token_ids[j][replace_index] = check_labels[j]
                        rerank_labels_this_place.append(replace_index)
                    else:
                        rerank_labels_this_place.append(check_candidates[j].index(check_labels[j]))
                all_rerank_labels.append(torch.tensor(rerank_labels_this_place, device=device))
            else:
                # eval: keep only the rows whose gold token is among the candidates
                kept_labels = []
                check_in_index = []
                for j in range(len(check_labels)):
                    if check_labels[j] in check_candidates[j]:
                        kept_labels.append(check_candidates[j].index(check_labels[j]))
                        check_in_index.append(j)

                if len(kept_labels) == 0:
                    # nothing to rerank at this position
                    continue
                all_rerank_labels.append(torch.tensor(kept_labels, device=device))

            #make context for rerank stage, 50256 is the token_id for <|endoftext|>
            sep_token = torch.ones(size = [candidate_token_ids.shape[0], 1], dtype = torch.long, device=device) * 50256
            candidate_context_ids = torch.cat([sep_token, candidate_token_ids, sep_token, candidate_token_ids], -1)

            # the cache returned by this call is intentionally not kept
            rerank_outputs = self.transformer(candidate_context_ids,
                            past_key_values=past_key_values,
                          )

            # hidden states of the second copy of the candidates
            rerank_hidden_states = rerank_outputs[0][:, 2+CAN_NUM:2+CAN_NUM*2]

            if not is_training:
                all_rerank_hidden_states.append(rerank_hidden_states[check_in_index])
                no_rerank_logits.append(candidate_token_logits[check_in_index])
            else:
                all_rerank_hidden_states.append(rerank_hidden_states)

        # ---- normal (language-modelling) loss over the full sequence ----
        hidden_states = torch.cat(hidden_states, 1)
        lm_logits = self.lm_head(hidden_states)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        normal_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # ---- rerank loss: classification over the CAN_NUM candidates ----
        if all_rerank_hidden_states:
            all_rerank_hidden_states = torch.cat(all_rerank_hidden_states, 0)
            all_rerank_logits = self.rerank_linear_head(all_rerank_hidden_states)
            all_rerank_logits = torch.reshape(all_rerank_logits, [-1, CAN_NUM])
            all_rerank_labels = torch.cat(all_rerank_labels, 0)
            rerank_loss = loss_fct(all_rerank_logits, all_rerank_labels)
        else:
            # No rerank example survived (possible in eval); torch.cat on an empty list would
            # raise, so report a zero rerank loss instead.
            all_rerank_labels = None
            rerank_loss = normal_loss.new_zeros(())

        #not sure which one to be used, loss is used to cal backward, others is for observation
        loss = normal_loss + rerank_loss

        # LM loss restricted to the candidate logits at the rerank places, for comparison with
        # the rerank loss; evaluation only
        normal_loss_in_rerank_place = None
        if not is_training:
            if no_rerank_logits:
                no_rerank_logits = torch.cat(no_rerank_logits, 0)
                no_rerank_logits = torch.reshape(no_rerank_logits, [-1, CAN_NUM])
                normal_loss_in_rerank_place = loss_fct(no_rerank_logits, all_rerank_labels)
            else:
                normal_loss_in_rerank_place = normal_loss.new_zeros(())

        if is_training:
            self.train()

        return {"loss": loss,
                "normal_loss": normal_loss,
                "normal_loss_in_rerank_place": normal_loss_in_rerank_place,
                "rerank_loss": rerank_loss,}

In [5]:
import os
import time
import datetime

import pandas as pd
import numpy as np
import random


import torch
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
torch.manual_seed(42)

from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup

In [19]:
# Shuffle with a fixed seed so the train/validation split is identical on every run,
# then take the two validation sets from the front of the shuffled frame.
SPLIT_SEED = 42
N_VALIDATION = 100000
N_INSIDE_VALIDATION = 10000

shuffled_df = pd.read_csv('data/wiki2021/wiki2021_0to4.csv').sample(frac=1, random_state=SPLIT_SEED)
validation_df = shuffled_df.iloc[:N_VALIDATION]
inside_validation_df = shuffled_df.iloc[N_VALIDATION:N_VALIDATION + N_INSIDE_VALIDATION]
train_df = shuffled_df.iloc[N_VALIDATION + N_INSIDE_VALIDATION:]

In [7]:
# Load the GPT tokenizer.
# The end-of-text token is reused as the padding token, so no new token is added to the vocabulary.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|endoftext|>') #gpt2-medium


print("The max model length is {} for this model, although the actual embedding size for GPT small is 768".format(tokenizer.model_max_length))
print("The beginning of sequence token {} token has the id {}".format(tokenizer.convert_ids_to_tokens(tokenizer.bos_token_id), tokenizer.bos_token_id))
print("The end of sequence token {} has the id {}".format(tokenizer.convert_ids_to_tokens(tokenizer.eos_token_id), tokenizer.eos_token_id))
print("The padding token {} has the id {}".format(tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id), tokenizer.pad_token_id))

# instantiate the model
model = myGPT2LMHeadModel.from_pretrained("gpt2", config=configuration)

# Keep the embedding matrix in sync with the tokenizer. With the tokenizer above this should
# leave the size unchanged; it only matters if special tokens are added to the tokenizer later.
model.resize_token_embeddings(len(tokenizer))

# Run on the GPU when one is available; fall back to the CPU instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

The max model length is 1024 for this model, although the actual embedding size for GPT small is 768
The beginning of sequence token <|endoftext|> token has the id 50256
The end of sequence token <|endoftext|> has the id 50256
The padding token <|endoftext|> has the id 50256


Some weights of myGPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['rerank_linear_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
class GPT2Dataset(Dataset):
    """Pre-tokenised text dataset yielding ``(input_ids, attention_mask)`` tensor pairs.

    Every text is wrapped in ``<|endoftext|>`` markers and passed to ``tokenizer`` with
    truncation and ``max_length`` padding. The most recent tokenizer output is also stored
    in ``debug['encodings_dict']``. ``gpt2_type`` is accepted but not used.
    """

    def __init__(self, txt_list, tokenizer, gpt2_type="gpt2", max_length=768):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        boundary = '<|endoftext|>'
        for raw_text in txt_list:
            encoded = tokenizer(
                boundary + raw_text + boundary,
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )
            # keep the latest encoding around for inspection in a later cell
            debug['encodings_dict'] = encoded

            self.input_ids += [torch.tensor(encoded['input_ids'])]
            self.attn_masks += [torch.tensor(encoded['attention_mask'])]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        ids = self.input_ids[idx]
        mask = self.attn_masks[idx]
        return ids, mask

In [21]:
# Tokenise each split (train, validation, inside-validation, in that order) to MAX_LEN tokens.
train_dataset, validation_dataset, inside_validation_dataset = (
    GPT2Dataset(split_df['text'], tokenizer, max_length=MAX_LEN)
    for split_df in (train_df, validation_df, inside_validation_df)
)

In [22]:
# Display the last tokenizer output that GPT2Dataset.__init__ stored in `debug` (ids and mask sanity check).
debug['encodings_dict']

{'input_ids': [50256, 287, 1703, 354, 270, 4914, 11, 12926, 13, 40700, 1359, 428, 47872, 1621, 379, 262, 4082, 286, 262, 3660, 6142, 3356, 290, 351, 1895, 284, 10092, 15424, 9640, 29587, 329, 625, 2319, 812, 11, 262, 2646, 19788, 319, 34286, 12, 11718, 5199, 9055, 290, 465, 636, 287, 262, 6282, 286, 262, 3298, 4009, 356, 783, 760, 355, 42793, 13, 44909, 1522, 416, 262, 1936, 3173, 286, 12352, 422, 9055, 338, 19336, 13, 26362, 290, 7124, 13, 464, 11648, 717, 44119, 379, 262, 1853, 3309, 590, 13741, 11117, 11, 5442, 262, 2159, 31535, 16854, 560, 6093, 48705, 11289, 329, 39883, 290, 262, 15518, 45470, 11289, 13, 317, 717, 12268, 373, 2716, 319, 2901, 1542, 11, 1853, 13, 2202, 2693, 860, 11, 1853, 11, 17741, 4803, 32862, 262, 11648], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [23]:
# DataLoaders for the three splits; all of them use the same batch size.
# Training batches are drawn in random order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=RandomSampler(train_dataset))

# Order is irrelevant for evaluation, so both validation loaders read sequentially.
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    sampler=SequentialSampler(validation_dataset),
)
inside_validation_dataloader = DataLoader(
    inside_validation_dataset,
    batch_size=batch_size,
    sampler=SequentialSampler(inside_validation_dataset),
)

In [12]:
# AdamW here is the class imported from the huggingface library, not the pytorch one.
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=epsilon)

In [13]:
# The scheduler is stepped once per batch, so the schedule length is
# (batches per epoch) x (epochs) -- not the number of training samples.
total_steps = epochs * len(train_dataloader)

# Linear warm-up for `warmup_steps` steps, then linear decay over the remaining steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [14]:
def format_time(elapsed):
    """Render a duration given in seconds as an ``h:mm:ss`` string (rounded to whole seconds)."""
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

In [15]:
def _evaluate(dataloader):
    """Run the notebook-level `model` over `dataloader` without gradients.

    Returns a dict with the per-batch mean of each value the model reports in evaluation
    mode: "loss", "normal_loss", "normal_loss_in_rerank_place" and "rerank_loss".
    Leaves the model in eval mode; the caller decides whether to switch back to training.
    """
    totals = {
        "loss": 0,
        "normal_loss": 0,
        "normal_loss_in_rerank_place": 0,
        "rerank_loss": 0,
    }

    model.eval()

    for batch in dataloader:
        b_input_ids = batch[0].to(device)
        b_labels = batch[0].to(device)  # labels are the inputs; the model shifts them itself
        b_masks = batch[1].to(device)

        with torch.no_grad():
            outputs = model(b_input_ids,
                            attention_mask = b_masks,
                            labels=b_labels,
                            is_training=False,)

        for key in totals:
            totals[key] += outputs[key].item()

    num_batches = len(dataloader)
    return {key: total / num_batches for key, total in totals.items()}


def _print_validation_report(stats, elapsed, prefix=""):
    """Print the averages returned by `_evaluate`; `prefix` is inserted before "Validation"."""
    print("  {}Validation Loss: {:.2f}".format(prefix, stats["loss"]))
    print("  Average {}Validation normal_loss: {:.2f}".format(prefix, stats["normal_loss"]))
    print("  Average {}Validation normal_loss_in_rerank_place: {:.2f}".format(prefix, stats["normal_loss_in_rerank_place"]))
    print("  Average {}Validation rerank_loss: {:.2f}".format(prefix, stats["rerank_loss"]))
    print("  {}Validation took: {:}".format(prefix, elapsed))


total_t0 = time.time()

training_stats = []

model = model.to(device)

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    total_train_loss = 0
    total_train_normal_loss = 0
    total_train_rerank_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)
        b_labels = batch[0].to(device)  # labels are the inputs; the model shifts them itself
        b_masks = batch[1].to(device)

        model.zero_grad()

        outputs = model(  b_input_ids,
                          labels=b_labels,
                          attention_mask = b_masks,
                          token_type_ids=None,
                          is_training=True,
                        )

        debug['outputs'] = outputs

        loss = outputs["loss"]

        batch_loss = loss.item()
        total_train_loss += batch_loss
        total_train_normal_loss += outputs["normal_loss"].item()
        total_train_rerank_loss += outputs["rerank_loss"].item()

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Every `sample_every` steps, report progress and evaluate on the small validation set.
        if step % sample_every == 0 and not step == 0:

            t1 = time.time()

            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), batch_loss, elapsed))

            inside_stats = _evaluate(inside_validation_dataloader)
            _print_validation_report(inside_stats, format_time(time.time() - t1), prefix="inside ")

            # _evaluate leaves the model in eval mode; resume training mode
            model.train()

    # Calculate the average loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)
    avg_train_normal_loss = total_train_normal_loss / len(train_dataloader)
    avg_train_rerank_loss = total_train_rerank_loss / len(train_dataloader)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Average training normal_loss: {0:.2f}".format(avg_train_normal_loss))
    print("  Average training rerank_loss: {0:.2f}".format(avg_train_rerank_loss))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    validation_stats = _evaluate(validation_dataloader)
    avg_val_loss = validation_stats["loss"]

    validation_time = format_time(time.time() - t0)

    _print_validation_report(validation_stats, validation_time)

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))
# print(f"Perplexity: {math.exp(eval_loss):.2f}")


Training...
  Batch 1,000  of  48,929. Loss: 4.907516956329346.   Elapsed: 0:38:15.
  inside Validation Loss: 4.90
  Average inside Validation normal_loss: 3.71
  Average inside Validation normal_loss_in_rerank_place: 1.01
  Average inside Validation rerank_loss: 1.18
  inside Validation took: 0:03:52
  Batch 2,000  of  48,929. Loss: 4.706604480743408.   Elapsed: 1:20:20.
  inside Validation Loss: 4.84
  Average inside Validation normal_loss: 3.69
  Average inside Validation normal_loss_in_rerank_place: 1.00
  Average inside Validation rerank_loss: 1.16
  inside Validation took: 0:03:50
  Batch 3,000  of  48,929. Loss: 4.863150596618652.   Elapsed: 2:02:18.
  inside Validation Loss: 4.84
  Average inside Validation normal_loss: 3.66
  Average inside Validation normal_loss_in_rerank_place: 1.00
  Average inside Validation rerank_loss: 1.17
  inside Validation took: 0:03:51
  Batch 4,000  of  48,929. Loss: 4.592685222625732.   Elapsed: 2:44:18.
  inside Validation Loss: 4.80
  Average i

  Batch 29,000  of  48,929. Loss: 4.401515960693359.   Elapsed: 20:12:10.
  inside Validation Loss: 4.58
  Average inside Validation normal_loss: 3.50
  Average inside Validation normal_loss_in_rerank_place: 0.97
  Average inside Validation rerank_loss: 1.09
  inside Validation took: 0:03:50
  Batch 30,000  of  48,929. Loss: 4.461573123931885.   Elapsed: 20:53:42.
  inside Validation Loss: 4.59
  Average inside Validation normal_loss: 3.50
  Average inside Validation normal_loss_in_rerank_place: 0.96
  Average inside Validation rerank_loss: 1.09
  inside Validation took: 0:03:51
  Batch 31,000  of  48,929. Loss: 4.443973064422607.   Elapsed: 21:35:14.
  inside Validation Loss: 4.58
  Average inside Validation normal_loss: 3.49
  Average inside Validation normal_loss_in_rerank_place: 0.97
  Average inside Validation rerank_loss: 1.09
  inside Validation took: 0:03:51
  Batch 32,000  of  48,929. Loss: 4.610875606536865.   Elapsed: 22:16:44.
  inside Validation Loss: 4.57
  Average inside

KeyboardInterrupt: 

In [24]:
for epoch_i in range(0, 1):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    total_train_loss = 0
    total_train_normal_loss = 0
    total_train_normal_loss_in_rerank_place = 0
    total_train_rerank_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)
        b_labels = batch[0].to(device)
        b_masks = batch[1].to(device)

        model.zero_grad()        

        outputs = model(  b_input_ids,
                          labels=b_labels, 
                          attention_mask = b_masks,
                          token_type_ids=None,
                          is_training=True,
                        )

        debug['outputs'] = outputs

        loss = outputs["loss"]
        normal_loss = outputs["normal_loss"]
        rerank_loss = outputs["rerank_loss"]

        batch_loss = loss.item()
        total_train_loss += batch_loss
        
        batch_normal_loss = normal_loss.item()
        total_train_normal_loss += batch_normal_loss
        
        batch_rerank_loss = rerank_loss.item()
        total_train_rerank_loss += batch_rerank_loss
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        # Get sample every x batches.
        if step % sample_every == 0 and not step == 0:
            
            t1 = time.time()

            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), batch_loss, elapsed))

            model.eval()

            total_eval_loss = 0
            total_eval_normal_loss = 0
            total_eval_normal_loss_in_rerank_place = 0
            total_eval_rerank_loss = 0

            # Evaluate data for one epoch
            for batch in inside_validation_dataloader:        
                b_input_ids = batch[0].to(device)
                b_labels = batch[0].to(device)
                b_masks = batch[1].to(device)

                with torch.no_grad():        

                    outputs  = model(b_input_ids, 
        #                            token_type_ids=None, 
                                     attention_mask = b_masks,
                                     labels=b_labels,
                                     is_training=False,)

                    loss = outputs["loss"]
                    normal_loss = outputs["normal_loss"]
                    normal_loss_in_rerank_place = outputs["normal_loss_in_rerank_place"]
                    rerank_loss = outputs["rerank_loss"]

                batch_loss = loss.item()
                total_eval_loss += batch_loss        

                batch_normal_loss = normal_loss.item()
                total_eval_normal_loss += batch_normal_loss

                batch_normal_loss_in_rerank_place = normal_loss_in_rerank_place.item()
                total_eval_normal_loss_in_rerank_place += batch_normal_loss_in_rerank_place

                batch_rerank_loss = rerank_loss.item()
                total_eval_rerank_loss += batch_rerank_loss

            avg_val_loss = total_eval_loss / len(inside_validation_dataloader)
            avg_val_normal_loss = total_eval_normal_loss / len(inside_validation_dataloader)       
            avg_val_normal_loss_in_rerank_place = total_eval_normal_loss_in_rerank_place / len(inside_validation_dataloader)       
            avg_val_rerank_loss = total_eval_rerank_loss / len(inside_validation_dataloader)    

            validation_time = format_time(time.time() - t1)    

            print("  inside Validation Loss: {0:.2f}".format(avg_val_loss))
            print("  Average inside Validation normal_loss: {0:.2f}".format(avg_val_normal_loss))
            print("  Average inside Validation normal_loss_in_rerank_place: {0:.2f}".format(avg_val_normal_loss_in_rerank_place))
            print("  Average inside Validation rerank_loss: {0:.2f}".format(avg_val_rerank_loss))
            print("  inside Validation took: {:}".format(validation_time))
            
            model.train()

    # Calculate the average loss over all of the batches.
    # Each total_train_* accumulator (defined before this point, not shown in
    # this part of the cell) is divided by len(train_dataloader).
    # NOTE(review): for a torch DataLoader len() is the number of batches, so
    # these are means of per-batch losses — confirm train_dataloader is one.
    avg_train_loss = total_train_loss / len(train_dataloader)       
    avg_train_normal_loss = total_train_normal_loss / len(train_dataloader)      
    avg_train_rerank_loss = total_train_rerank_loss / len(train_dataloader)       
    
    # Measure how long this epoch took.
    # Must be computed here: t0 is reassigned below to time the validation pass.
    training_time = format_time(time.time() - t0)

    # Report the three training averages and the elapsed time for this epoch.
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Average training normal_loss: {0:.2f}".format(avg_train_normal_loss))
    print("  Average training rerank_loss: {0:.2f}".format(avg_train_rerank_loss))
    print("  Training epoch took: {:}".format(training_time))
        
    # ========================================
    #               Validation
    # ========================================
    # End-of-epoch evaluation: run the model over every batch of
    # validation_dataloader with gradient tracking disabled, accumulate the four
    # loss values it returns, then average and print them.
    # NOTE(review): this repeats the accumulate/average/print logic of the
    # periodic "inside" validation earlier in this loop — candidate for a
    # shared helper function.

    print("")
    print("Running Validation...")

    # t0 is reused as the validation timer (training_time was computed above).
    t0 = time.time()

    # NOTE(review): nothing in this block switches back to train mode, so the
    # model stays in eval mode afterwards — presumably the next epoch calls
    # model.train() before training; confirm.
    model.eval()

    # Running sums of the per-batch scalar losses.
    total_eval_loss = 0
    total_eval_normal_loss = 0
    total_eval_normal_loss_in_rerank_place = 0
    total_eval_rerank_loss = 0


    # Evaluate data for one epoch
    for batch in validation_dataloader:        
        # Inputs and labels are both taken from batch[0]; any shifting for
        # next-token prediction is presumably done inside the model — confirm.
        b_input_ids = batch[0].to(device)
        b_labels = batch[0].to(device)
        # batch[1] is passed to the model as the attention mask.
        b_masks = batch[1].to(device)
        
        with torch.no_grad():        

            # is_training is a keyword of this model's forward; its effect is
            # defined in the model class, not visible here.
            outputs  = model(b_input_ids, 
#                            token_type_ids=None, 
                             attention_mask = b_masks,
                             labels=b_labels,
                             is_training=False,)
          
            # The model output is indexed by name for the four loss tensors.
            loss = outputs["loss"]
            normal_loss = outputs["normal_loss"]
            normal_loss_in_rerank_place = outputs["normal_loss_in_rerank_place"]
            rerank_loss = outputs["rerank_loss"]
        
        # .item() turns each loss tensor into a plain Python number before it
        # is added to the running totals.
        batch_loss = loss.item()
        total_eval_loss += batch_loss        
        
        batch_normal_loss = normal_loss.item()
        total_eval_normal_loss += batch_normal_loss
        
        batch_normal_loss_in_rerank_place = normal_loss_in_rerank_place.item()
        total_eval_normal_loss_in_rerank_place += batch_normal_loss_in_rerank_place
        
        batch_rerank_loss = rerank_loss.item()
        total_eval_rerank_loss += batch_rerank_loss

    # Divide by len(validation_dataloader): every batch counts equally,
    # regardless of how many examples it holds.
    avg_val_loss = total_eval_loss / len(validation_dataloader)
    avg_val_normal_loss = total_eval_normal_loss / len(validation_dataloader)       
    avg_val_normal_loss_in_rerank_place = total_eval_normal_loss_in_rerank_place / len(validation_dataloader)       
    avg_val_rerank_loss = total_eval_rerank_loss / len(validation_dataloader)    
    
    validation_time = format_time(time.time() - t0)    

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Average Validation normal_loss: {0:.2f}".format(avg_val_normal_loss))
    print("  Average Validation normal_loss_in_rerank_place: {0:.2f}".format(avg_val_normal_loss_in_rerank_place))
    print("  Average Validation rerank_loss: {0:.2f}".format(avg_val_rerank_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    # Only the overall training/validation losses and the two timings are
    # stored; the normal_loss / rerank_loss averages computed above are printed
    # but not kept in training_stats.
    training_stats.append(
        {
            'epoch': epoch_i + 1,  # epoch_i is stored 1-based
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

# Final report once every epoch has finished.
print("")
print("Training complete!")

# Wall-clock time since total_t0, rendered by the notebook's format_time helper.
total_elapsed = format_time(time.time() - total_t0)
print("Total training took {:} (h:mm:ss)".format(total_elapsed))
# Perplexity reporting is currently disabled:
# print(f"Perplexity: {math.exp(eval_loss):.2f}")


Training...
  Batch 1,000  of  48,832. Loss: 4.53739070892334.   Elapsed: 0:38:13.
  inside Validation Loss: 4.56
  Average inside Validation normal_loss: 3.49
  Average inside Validation normal_loss_in_rerank_place: 0.97
  Average inside Validation rerank_loss: 1.07
  inside Validation took: 0:03:52


KeyboardInterrupt: 

In [25]:
# Inspect the loss tensor stored under debug['outputs'].
# `debug` is the global scratch dict created in the configuration cell; it is
# presumably filled with the model outputs by code outside this cell — confirm.
# NOTE(review): this cell depends on that earlier execution having populated
# the dict, so it will fail on a fresh kernel unless the producing cell ran first.
debug['outputs']['loss']

tensor(4.5916, device='cuda:0', grad_fn=<AddBackward0>)

In [17]:
# model.save_pretrained("results/baseline_wiki2021")

## Case analysis

In [None]:
# Take only the first validation batch for the qualitative inspection below;
# eval_batch is None when the dataloader yields nothing.
eval_batch = next(iter(validation_dataloader), None)

In [None]:
# Decode row 3 of the batch's first element back to text. batch[0] is what the
# validation loop feeds to the model as input ids; row 3 is the example studied
# in the cells that follow.
tokenizer.decode(eval_batch[0][3])

In [None]:
# Build the "normal stage" input for one hand-picked validation example.
model.eval()

# Position of the token to predict. Earlier notes: 48 --> "colour", 50 --> "with".
predict_token_index = 50
# No cached keys/values yet: the prefix is encoded from scratch.
past_key_values = None

# Normal stage: the context is every token before the target position and the
# label is the token at that position (row 3 of the batch's input ids).
example_ids = eval_batch[0][3]
labels = example_ids[predict_token_index]
prefix_ids = example_ids[0:predict_token_index]

# Duplicate the prefix into two identical rows: shape (2, predict_token_index).
segment_input_ids = torch.reshape(
    torch.cat([prefix_ids, prefix_ids], 0),
    [-1, predict_token_index],
)
# Move to the notebook-wide `device` (the one the validation loop and the
# rerank separator tokens use) instead of hard-coding CUDA, so tensors that are
# concatenated later live on the same device.
segment_input_ids = segment_input_ids.to(device)
segment_input_ids

In [None]:
# Show the full example first, then the prefix that is actually fed to the
# model (row 0 of the duplicated prefix batch).
full_example_text = tokenizer.decode(eval_batch[0][3])
prefix_text = tokenizer.decode(segment_input_ids[0])

print(full_example_text, '\n')
print(prefix_text)

In [None]:
# Encode the prefix from scratch and keep its key/value cache for the rerank
# stage.
#
# `past_key_values=None` is passed explicitly rather than through the notebook
# variable: this cell overwrites that variable below, so re-running the cell
# would otherwise feed the prefix's own cache back in and change the result.
# Gradients are not needed for inspection, so the forward pass runs under
# no_grad like the validation loops.
with torch.no_grad():
    segment_outputs = model.transformer(
        segment_input_ids,
        past_key_values=None,
    )

# First output: hidden states for every prefix position.
segment_hidden = segment_outputs[0]
# Second output: the cache consumed by the rerank forward pass.
past_key_values = segment_outputs[1]

In [None]:
# Shape of the first transformer output (the tensor stored as segment_hidden).
segment_outputs[0].shape

In [None]:
# Only the hidden state at the last prefix position is needed to score the
# next token; project it through the LM head to get vocabulary logits.
last_position_hidden = segment_hidden[:, -1]
logits_before_rerank = model.lm_head(last_position_hidden)

In [None]:
# Logits of the CAN_NUM highest-scoring next-token candidates in each row.
top_candidate_logits, _ = torch.topk(logits_before_rerank, CAN_NUM)
top_candidate_logits

In [None]:
# Renormalise the top-CAN_NUM logits into a distribution over just those
# candidates (softmax along the candidate axis).
candidate_logits = torch.topk(logits_before_rerank, CAN_NUM).values
nn.functional.softmax(candidate_logits, dim=1)

In [None]:
# Vocabulary ids of the CAN_NUM highest-logit next tokens for every row.
candidate_token_ids = torch.topk(logits_before_rerank, CAN_NUM).indices

# Print the first row's candidates as text, one per line.
for candidate_id in candidate_token_ids[0]:
    print(tokenizer.decode(candidate_id))

In [None]:
# The rerank label is the same target token as `labels` from the normal stage.
rerank_labels = labels
# Show the target token as text.
tokenizer.decode(rerank_labels)

In [None]:
# Build the rerank-stage context for every row:
#   <|endoftext|> c_1 .. c_k <|endoftext|> c_1 .. c_k
# where c_1..c_k are the top-CAN_NUM candidate ids.
# 50256 is the token id of GPT-2's <|endoftext|>, used here as the separator.
eos_token_id = 50256

# One separator per row, created on the candidates' own device so the
# concatenation below cannot fail with a device mismatch.
sep_token = torch.full(
    (candidate_token_ids.shape[0], 1),
    eos_token_id,
    dtype=torch.long,
    device=candidate_token_ids.device,
)
candidate_context_ids = torch.cat([sep_token, candidate_token_ids, sep_token, candidate_token_ids], -1)
tokenizer.decode(candidate_context_ids[0])

In [None]:
# Rerank stage: run the candidate context through the transformer on top of
# the cached prefix. Inspection only, so no gradients are tracked (as in the
# validation loops).
with torch.no_grad():
    rerank_outputs = model.transformer(
        candidate_context_ids,
        past_key_values=past_key_values,
    )

# The context is [sep, c_1..c_k, sep, c_1..c_k]; keep the hidden states of the
# second copy of the candidates, i.e. the CAN_NUM positions starting right
# after the second separator.
rerank_start = 2 + CAN_NUM
rerank_hidden_states = rerank_outputs[0][:, rerank_start:rerank_start + CAN_NUM]
rerank_hidden_states.shape

In [None]:
# Apply the rerank head to each candidate's hidden state. In the model class
# defined at the top of the notebook this head is a bias-free Linear(n_embd, 1),
# i.e. one scalar score per candidate — assuming `model` is that class.
model.rerank_linear_head(rerank_hidden_states)

In [None]:
# Score every candidate with the rerank head, then apply softmax along dim 1
# (the candidate axis) to compare the candidates as probabilities.
rerank_scores = model.rerank_linear_head(rerank_hidden_states)
nn.functional.softmax(rerank_scores, dim=1)