In [1]:
import os
import time
import datetime
import torch
import math
import copy 
import random
from packaging import version
import pandas as pd
import numpy as np

from typing import Callable, Dict, Optional, Tuple, Union, Any, collections

from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss, MSELoss
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler

from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel, GPT2Model
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import Trainer, TrainerState, TrainingArguments

from transformers.file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_sagemaker_distributed_available,
    is_torch_tpu_available,
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
)

from transformers.utils import logging
logger = logging.get_logger(__name__)

In [2]:
torch.manual_seed(42)

# Data / rerank settings
batch_size = 4
MAX_LEN = 128        # every text is truncated / padded to this many tokens
CAN_NUM = 100        # number of top-k LM-head candidates that get re-ranked
num_of_rerank = 30   # rerank places sampled (without replacement) per sequence
# Rerank places are drawn from positions 1..MAX_LEN-1, so there must be enough of them.
assert 0 < num_of_rerank < MAX_LEN, "num_of_rerank must be in [1, MAX_LEN - 1]"

# some parameters I cooked up that work reasonably well
epochs = 1
learning_rate = 5e-4
warmup_steps = 100   # integer number of scheduler steps for the linear warmup
epsilon = 1e-8

# Scratch dict that the model / dataset code writes intermediate values into for inspection.
debug = {}

In [3]:
# Stock GPT-2 (small) configuration; it is passed to myGPT2LMHeadModel.from_pretrained below.
# output_hidden_states=False: the transformer does not return every layer's hidden states.
configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)
VOCAB_SIZE = configuration.vocab_size
VOCAB_SIZE

50257

In [4]:
class myGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 LM head model with an extra candidate re-ranking head.

    ``forward`` walks each sequence segment by segment, reusing the key/value cache.
    It works on ``num_of_rerank`` randomly chosen positions. At each one it:

    - takes the LM head's top-``CAN_NUM`` next-token candidates;
    - feeds ``<|endoftext|> c_1..c_K <|endoftext|> c_1..c_K`` through the transformer,
      conditioned on the cached prefix;
    - scores every candidate with ``rerank_linear_head`` from the hidden state of its
      second occurrence.

    (position, example) pairs whose gold next token is not among the candidates are
    skipped.

    Relies on the module-level globals ``num_of_rerank``, ``CAN_NUM`` and ``debug``.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): the parent constructor already builds ``transformer`` and ``lm_head``;
        # they are rebuilt here and (re)initialised / tied again by ``init_weights`` below.
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # One scalar re-ranking score per candidate hidden state.
        self.rerank_linear_head = nn.Linear(config.n_embd, 1, bias=False)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        labels=None,
        is_training=False,
    ):
        r"""Run the segment-wise LM pass and the candidate re-ranking pass.

        Args:
            input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
                Token ids. ``sequence_length`` must be greater than ``num_of_rerank``.
            labels (torch.LongTensor of shape (batch_size, sequence_length)):
                Targets aligned with ``input_ids`` (the notebook passes ``labels = input_ids``).
                At a rerank place ``p`` the model predicts ``labels[:, p]`` from tokens
                ``0..p-1``. Targets not among the ``CAN_NUM`` candidates (including ``-100``)
                are excluded.
            is_training (bool): selects the returned dictionary. When True,
                ``self.transformer`` is put back into training mode before returning.

        Returns:
            dict. Let N be the number of kept (rerank place, example) pairs; N may be 0.

            Training returns:

            - ``normal_loss_in_rerank_place``, shape (N,): per-pair cross-entropy.
            - ``rerank_loss``, shape (N,): per-pair cross-entropy.

            Evaluation additionally returns:

            - ``all_rerank_logits`` and ``no_rerank_logits``, shape (N, CAN_NUM).
            - ``difficult_level``, shape (N,): rank of the gold token in the LM's
              candidate list.
            - ``all_candidate_token_ids``, shape (N, CAN_NUM).
            - ``all_prediction_ids``, shape (N,).
            - ``all_input_ids``, shape (N, sequence_length): each prefix, right-padded
              with 50256.
        """
        # Dropout etc. are switched off for the whole pass, including during training;
        # training mode is restored at the end when ``is_training`` is True.
        self.transformer.eval()

        debug['input_ids'] = input_ids
        debug['labels'] = labels

        device = input_ids.device
        seq_len = input_ids.shape[1]

        # num_of_rerank distinct positions in [1, seq_len); 0 and seq_len are added as
        # sentinels so that consecutive entries delimit the segments.
        rerank_places = random.sample(range(1, seq_len), k=num_of_rerank)
        rerank_places = np.concatenate(([0], np.sort(rerank_places), [seq_len]))
        debug['rerank_places'] = rerank_places

        past_key_values = None
        kept_hidden_states = []     # per rerank place: (n_i, CAN_NUM, n_embd)
        kept_rank_labels = []       # per rerank place: (n_i,) rank of the gold token
        kept_candidate_logits = []  # per rerank place: (n_i, CAN_NUM) LM-head logits
        # Evaluation-only bookkeeping.
        kept_candidate_ids = []
        kept_prediction_ids = []
        kept_prefixes = []
        check_out_num = 0  # pairs whose gold token is not among the candidates

        for i in range(num_of_rerank + 1):
            start, end = rerank_places[i], rerank_places[i + 1]

            # Normal stage: extend the cached context with the next segment.
            segment_input_ids = input_ids[:, start:end]
            debug['segment_input_ids'] = segment_input_ids

            segment_outputs = self.transformer(
                segment_input_ids,
                past_key_values=past_key_values
            )
            debug['segment_outputs'] = segment_outputs

            segment_hidden = segment_outputs[0]
            past_key_values = segment_outputs[1]

            # The last segment runs to the end of the sequence; no rerank place follows it.
            if i == num_of_rerank:
                break

            # Rerank stage: the last hidden state of the segment predicts token ``end``.
            logits_before_rerank = self.lm_head(segment_hidden[:, -1, :])
            candidate_token_logits, candidate_token_ids = torch.topk(logits_before_rerank, CAN_NUM)
            gold_token_ids = labels[..., end]

            debug['candidate_token_ids'] = candidate_token_ids
            debug['rerank_labels'] = gold_token_ids

            gold_list = gold_token_ids.tolist()
            candidate_lists = candidate_token_ids.tolist()
            assert len(gold_list) == len(candidate_lists)

            # Keep only the examples whose gold token is among the candidates.
            rank_labels = []
            check_in_index = []
            for row, (gold, candidates) in enumerate(zip(gold_list, candidate_lists)):
                if gold in candidates:
                    rank_labels.append(candidates.index(gold))
                    check_in_index.append(row)
                else:
                    check_out_num += 1
            if not check_in_index:
                continue
            kept_rank_labels.append(torch.tensor(rank_labels, device=device))

            # Rerank context: <|endoftext|> c_1..c_K <|endoftext|> c_1..c_K (50256 = <|endoftext|>).
            sep_token = torch.full((candidate_token_ids.shape[0], 1), 50256,
                                   dtype=torch.long, device=device)
            candidate_context_ids = torch.cat(
                [sep_token, candidate_token_ids, sep_token, candidate_token_ids], -1)

            # The rerank pass reads the cached prefix, but its own cache is discarded, so the
            # next segment continues from the real text only.
            rerank_outputs = self.transformer(candidate_context_ids,
                                              past_key_values=past_key_values)
            # Hidden states of the second copy of the candidates (positions K+2 .. 2K+1).
            rerank_hidden_states = rerank_outputs[0][:, 2 + CAN_NUM:2 + CAN_NUM * 2]

            kept_hidden_states.append(rerank_hidden_states[check_in_index])
            kept_candidate_logits.append(candidate_token_logits[check_in_index])

            if not is_training:
                kept_candidate_ids.append(candidate_token_ids[check_in_index])
                kept_prediction_ids.append(input_ids[:, end][check_in_index])
                kept_prefixes.append(input_ids[:, :end][check_in_index])

        debug['check_out_num'] = check_out_num

        if is_training:
            self.transformer.train()

        (all_rerank_logits, no_rerank_logits, all_rerank_labels,
         rerank_loss, normal_loss_in_rerank_place) = self._rerank_losses(
            kept_hidden_states, kept_candidate_logits, kept_rank_labels)

        if is_training:
            return {"normal_loss_in_rerank_place": normal_loss_in_rerank_place,
                    "rerank_loss": rerank_loss,}

        # For evaluation, also return what is needed to analyse performance per difficulty level.
        all_candidate_token_ids = self._cat_or_empty(kept_candidate_ids, (0, CAN_NUM), torch.long, device)
        all_prediction_ids = self._cat_or_empty(kept_prediction_ids, (0,), torch.long, device)
        all_input_ids = self._pad_prefixes(kept_prefixes, seq_len, device)

        return {"all_rerank_logits": all_rerank_logits,
                "no_rerank_logits": no_rerank_logits,
                "difficult_level": all_rerank_labels,
                "rerank_loss": rerank_loss,
                "normal_loss_in_rerank_place": normal_loss_in_rerank_place,
                "all_candidate_token_ids": all_candidate_token_ids,
                "all_prediction_ids": all_prediction_ids,
                "all_input_ids": all_input_ids}

    @staticmethod
    def _cat_or_empty(chunks, empty_shape, dtype, device):
        """Concatenate ``chunks`` along dim 0, or return an empty tensor when there are none."""
        if chunks:
            return torch.cat(chunks, 0)
        return torch.zeros(empty_shape, dtype=dtype, device=device)

    def _pad_prefixes(self, prefixes, width, device):
        """Right-pad each (n_i, len_i) prefix batch to ``width`` with 50256 and concatenate them."""
        padded = []
        for prefix in prefixes:
            target = torch.full((prefix.shape[0], width), 50256, dtype=torch.long, device=device)
            target[:, :prefix.shape[1]] = prefix
            padded.append(target)
        return self._cat_or_empty(padded, (0, width), torch.long, device)

    def _rerank_losses(self, hidden_chunks, logit_chunks, label_chunks):
        """Score the kept candidates and compute per-pair losses for both heads.

        Returns a tuple with these shapes:

        - ``rerank_logits``: (N, CAN_NUM)
        - ``candidate_logits``: (N, CAN_NUM)
        - ``rank_labels``: (N,)
        - ``rerank_loss``: (N,)
        - ``normal_loss``: (N,)

        When nothing was kept (N == 0), empty tensors are returned instead of failing
        in ``torch.cat``.
        """
        weight = self.rerank_linear_head.weight
        hidden = self._cat_or_empty(hidden_chunks, (0, CAN_NUM, self.config.n_embd),
                                    weight.dtype, weight.device)
        candidate_logits = self._cat_or_empty(logit_chunks, (0, CAN_NUM), weight.dtype, weight.device)
        rank_labels = self._cat_or_empty(label_chunks, (0,), torch.long, weight.device)

        # (N, CAN_NUM, 1) -> (N, CAN_NUM); unlike reshape(-1, CAN_NUM) this also works for N == 0.
        rerank_logits = self.rerank_linear_head(hidden).squeeze(-1)

        loss_fct = CrossEntropyLoss(reduction='none')
        rerank_loss = loss_fct(rerank_logits, rank_labels)
        # "Normal" loss: the same rank target scored with the LM head's own candidate logits,
        # for comparison with the re-ranker.
        normal_loss = loss_fct(candidate_logits, rank_labels)
        return rerank_logits, candidate_logits, rank_labels, rerank_loss, normal_loss

In [5]:
# WikiText-2 splits, read from a path relative to the notebook. The cluster copy lives under
# /mnt/nfs/work1/llcao/zonghaiyao/LM/data/wikitext-2/
train_df, validation_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "data/wikitext-2/my_train.csv",
        "data/wikitext-2/my_validation.csv",
        "data/wikitext-2/my_test.csv",
    )
)

In [6]:
# Load the GPT-2 (small) tokenizer. GPT-2 has no padding token, so the existing
# <|endoftext|> token is reused for padding instead of adding a new token.
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>') #gpt2-medium
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|endoftext|>')


print("The max model length is {} for this model, although the actual embedding size for GPT small is 768".format(tokenizer.model_max_length))
print("The beginning of sequence token {} token has the id {}".format(tokenizer.convert_ids_to_tokens(tokenizer.bos_token_id), tokenizer.bos_token_id))
print("The end of sequence token {} has the id {}".format(tokenizer.convert_ids_to_tokens(tokenizer.eos_token_id), tokenizer.eos_token_id))
print("The padding token {} has the id {}".format(tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id), tokenizer.pad_token_id))

# Instantiate the model from a previously fine-tuned checkpoint.
# To start from stock GPT-2 instead: myGPT2LMHeadModel.from_pretrained("gpt2", config=configuration)
model = myGPT2LMHeadModel.from_pretrained("results/baseline_wiki2/exclude_cases_label_not_in_candidates/0",
                                          config=configuration)

# Keep the embedding matrix the same size as the tokenizer vocabulary. Because the pad token
# reuses <|endoftext|>, no tokens were added and this should be a no-op here; it only matters
# if new special tokens are introduced (e.g. the commented-out tokenizer call above).
model.resize_token_embeddings(len(tokenizer))

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # Wrap only when several GPUs are present. DataParallel splits each batch along dim 0,
    # e.g. [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs. An unwrapped model keeps
    # direct attribute access (model.transformer, model.lm_head, ...) working on one GPU.
    model = torch.nn.DataParallel(model)


# Tell pytorch to run this model on the GPU.
device = torch.device("cuda")
model.to(device)

# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

The max model length is 1024 for this model, although the actual embedding size for GPT small is 768
The beginning of sequence token <|endoftext|> token has the id 50256
The end of sequence token <|endoftext|> has the id 50256
The padding token <|endoftext|> has the id 50256
Let's use 4 GPUs!


In [7]:
class GPT2Dataset(Dataset):
    """Tokenised corpus holding one fixed-length tensor of token ids per input text.

    Each text is encoded by ``tokenizer`` with truncation and padding to ``max_length``.
    ``gpt2_type`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, txt_list, tokenizer, gpt2_type="gpt2", max_length=768):
        self.tokenizer = tokenizer
        self.input_ids = [self._encode(text, max_length) for text in txt_list]

    def _encode(self, text, max_length):
        """Tokenise one text and return its padded ids as a tensor."""
        encoded = self.tokenizer(text, truncation=True, max_length=max_length, padding="max_length")
        # Keep the most recent encoding around for interactive inspection.
        debug['encodings_dict'] = encoded
        return torch.tensor(encoded['input_ids'])

    def __getitem__(self, idx):
        return self.input_ids[idx]

    def __len__(self):
        return len(self.input_ids)

In [8]:
# Tokenise every split into fixed-length (MAX_LEN) sequences.
train_dataset, validation_dataset, test_dataset = (
    GPT2Dataset(split_df['text'], tokenizer, max_length=MAX_LEN)
    for split_df in (train_df, validation_df, test_df)
)

In [9]:
# DataLoaders for the three splits. Training batches are drawn in random order.
train_dataloader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=batch_size,
)

# Order does not matter for evaluation, so the validation and test splits are read sequentially.
validation_dataloader, test_dataloader = (
    DataLoader(split_dataset, sampler=SequentialSampler(split_dataset), batch_size=batch_size)
    for split_dataset in (validation_dataset, test_dataset)
)

In [10]:
# AdamW here is the Hugging Face implementation imported from transformers, not
# torch.optim.AdamW; only the learning rate and epsilon are set explicitly.
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=epsilon)

In [11]:
# The optimizer steps once per batch, so the schedule spans (batches per epoch) x (epochs)
# steps. This counts batches, not training samples.
total_steps = epochs * len(train_dataloader)

# Linear warmup for warmup_steps steps, then linear decay over the remaining steps up to
# total_steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [12]:
def format_time(elapsed):
    """Format a duration in seconds as ``h:mm:ss`` (``d day(s), h:mm:ss`` beyond 24 hours).

    The value is rounded to the nearest whole second before formatting.
    """
    rounded_seconds = int(round(elapsed))
    return str(datetime.timedelta(seconds=rounded_seconds))

In [13]:
def softmax(x):
    """Compute softmax values for one set of scores ``x`` (shifted by its max for stability)."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()


def cal_entropy_difficulty_level(x):
    """Base-10 entropy of ``softmax(x)`` over one row of candidate logits.

    A flatter distribution over the LM head's candidates gives a higher value.
    Probabilities that underflow to exactly 0 contribute 0 (the limit of p * log p).
    Without this, 0 * log10(0) would turn the whole result into NaN.
    """
    probs = softmax(np.asarray(x, dtype=np.float64))
    probs = probs[probs > 0]
    return -np.sum(probs * np.log10(probs))


total_t0 = time.time()

training_stats = []

model = model.to(device)

for epoch_i in range(0, epochs):

    # ========================================
    #               Training
    # ========================================

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()

    total_train_loss = 0
    total_train_normal_loss = 0
    total_train_rerank_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):
        model.zero_grad()

        outputs = model(input_ids=batch,   # batch_input_ids
                        labels=batch,      # batch_labels
                        is_training=True,
                        )

        debug['outputs'] = outputs

        # Each entry holds one loss per kept (rerank place, example) pair; average per batch.
        normal_loss = outputs["normal_loss_in_rerank_place"].mean()
        rerank_loss = outputs["rerank_loss"].mean()
        loss = normal_loss + rerank_loss

        total_train_loss += loss.item()
        total_train_normal_loss += normal_loss.item()
        total_train_rerank_loss += rerank_loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()

    # Calculate the average loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)
    avg_train_normal_loss = total_train_normal_loss / len(train_dataloader)
    avg_train_rerank_loss = total_train_rerank_loss / len(train_dataloader)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss:", avg_train_loss)
    print("  Average training normal_loss:", avg_train_normal_loss)
    print("  Average training rerank_loss:", avg_train_rerank_loss)
    print("  Training epoch took:", training_time)

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    model.eval()

    total_eval_loss = 0
    total_eval_normal_loss = 0
    total_eval_rerank_loss = 0

    all_evaluate_rerank_logits = []
    all_evaluate_normal_logits = []
    all_evaluate_difficult_level = []
    all_evaluate_rerank_loss = []
    all_evaluate_normal_loss_in_rerank_place = []
    all_evaluate_ground_true = []
    all_evaluate_inputs_text = []
    all_evaluate_candidate_token = []

    # Evaluate data for one epoch
    for batch in validation_dataloader:
        with torch.no_grad():
            outputs = model(input_ids=batch,   # batch_input_ids
                            labels=batch,      # batch_labels
                            is_training=False,
                            )

            normal_loss = outputs["normal_loss_in_rerank_place"].mean()
            rerank_loss = outputs["rerank_loss"].mean()
            loss = normal_loss + rerank_loss

        total_eval_loss += loss.item()
        total_eval_normal_loss += normal_loss.item()
        total_eval_rerank_loss += rerank_loss.item()

        # Fine-grained evaluation: one row per kept (rerank place, example) pair.
        all_evaluate_rerank_logits.extend(outputs["all_rerank_logits"].tolist())
        all_evaluate_normal_logits.extend(outputs["no_rerank_logits"].tolist())
        all_evaluate_difficult_level.extend(outputs["difficult_level"].tolist())
        all_evaluate_rerank_loss.extend(outputs["rerank_loss"].tolist())
        all_evaluate_normal_loss_in_rerank_place.extend(outputs["normal_loss_in_rerank_place"].tolist())
        all_evaluate_ground_true.extend(tokenizer.batch_decode(outputs['all_prediction_ids'], skip_special_tokens=True))
        all_evaluate_inputs_text.extend(tokenizer.batch_decode(outputs['all_input_ids'], skip_special_tokens=True))
        all_evaluate_candidate_token.extend([tokenizer.batch_decode(ids) for ids in outputs['all_candidate_token_ids']])

    avg_val_loss = total_eval_loss / len(validation_dataloader)
    avg_val_normal_loss = total_eval_normal_loss / len(validation_dataloader)
    avg_val_rerank_loss = total_eval_rerank_loss / len(validation_dataloader)

    validation_time = format_time(time.time() - t0)

    print("  Validation Loss:", avg_val_loss)
    print("  Average Validation normal_loss:", avg_val_normal_loss)
    print("  Average Validation rerank_loss:", avg_val_rerank_loss)
    print("  Validation took:", validation_time)

    fine_grained_evaluation = pd.DataFrame({"rerank_logits": all_evaluate_rerank_logits,
                                            "normal_logits": all_evaluate_normal_logits,
                                            "ground_true_difficulty_level": all_evaluate_difficult_level,
                                            "rerank_loss": all_evaluate_rerank_loss,
                                            "normal_loss": all_evaluate_normal_loss_in_rerank_place,
                                            "ground_true": all_evaluate_ground_true,
                                            "inputs_text": all_evaluate_inputs_text,
                                            "candidate_tokens": all_evaluate_candidate_token,
                                            })

    fine_grained_evaluation['entropy_difficulty_level'] = fine_grained_evaluation['normal_logits'].apply(cal_entropy_difficulty_level)
    # fine_grained_evaluation.to_pickle("results/baseline_wiki2/exclude_cases_label_not_in_candidates/0/fine_grained_evaluation.pkl")

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time,
            'fine_grained_evaluation': fine_grained_evaluation,
        }
    )

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))


Training...


RuntimeError: CUDA out of memory. Tried to allocate 148.00 MiB (GPU 0; 11.93 GiB total capacity; 10.61 GiB already allocated; 106.06 MiB free; 11.29 GiB reserved in total by PyTorch)

In [None]:
# Preview only: every row carries two lists of CAN_NUM logits plus the decoded candidates.
fine_grained_evaluation.head()

In [None]:
# a = pd.read_pickle("results/baseline_wiki2/exclude_cases_label_not_in_candidates/0/fine_grained_evaluation.pkl")

In [None]:
# model.module.save_pretrained("results/baseline_wiki2/exclude_cases_label_not_in_candidates/0")

## Case analysis

In [None]:
# First validation batch for the case study below (None if the loader is empty).
eval_batch = next(iter(validation_dataloader), None)

In [None]:
# Example 3 of the batch. The dataset yields one (MAX_LEN,) tensor per item, so eval_batch is
# (batch_size, MAX_LEN); eval_batch[0][3] would be a single token, not an example.
tokenizer.decode(eval_batch[3])

In [None]:
model.eval()
# Case study on example 3 of the batch: predict the token at predict_token_index from the
# tokens before it (48 --> "colour", 50 --> "with" for this example).
predict_token_index = 50
past_key_values = None
example_ids = eval_batch[3]
# normal stage: the prefix before the predicted token, and that token as the label
segment_input_ids = example_ids[0:predict_token_index]
labels = example_ids[predict_token_index]
# Duplicate the prefix into a batch of two identical rows: shape (2, predict_token_index).
segment_input_ids = torch.stack([segment_input_ids, segment_input_ids], 0)
segment_input_ids = segment_input_ids.cuda()
segment_input_ids

In [None]:
# Full example 3 of the batch, then the prefix fed to the model.
print(tokenizer.decode(eval_batch[3]), '\n')
print(tokenizer.decode(segment_input_ids[0]))

In [None]:
# Run the prefix through the bare GPT-2 transformer. DataParallel does not forward attribute
# access such as .transformer to the wrapped model, so unwrap it when present.
with torch.no_grad():
    segment_outputs = getattr(model, "module", model).transformer(
        segment_input_ids,
        past_key_values = past_key_values
    )

segment_hidden = segment_outputs[0]
past_key_values = segment_outputs[1]

In [None]:
segment_hidden.shape  # same tensor as segment_outputs[0]

In [None]:
# LM-head logits for the token after the prefix (last position); unwrap DataParallel if present.
with torch.no_grad():
    logits_before_rerank = getattr(model, "module", model).lm_head(segment_hidden[:, -1, :])

In [None]:
logits_before_rerank.topk(CAN_NUM).values  # the CAN_NUM largest LM logits, descending

In [None]:
logits_before_rerank.topk(CAN_NUM).values.softmax(dim=1)  # renormalised over the candidates only

In [None]:
candidate_token_ids = logits_before_rerank.topk(CAN_NUM).indices
# Print the CAN_NUM candidate tokens of the first row, most likely first.
for i, token_id in enumerate(candidate_token_ids[0]):
    print(tokenizer.decode(token_id))

In [None]:
# Gold next token for the prefix.
rerank_labels = labels
tokenizer.decode(token_ids=rerank_labels)

In [None]:
# Rerank context: <|endoftext|> c_1..c_K <|endoftext|> c_1..c_K (50256 is the id of <|endoftext|>).
sep_token = torch.full((candidate_token_ids.shape[0], 1), 50256, dtype=torch.long, device=device)
candidate_context_ids = torch.cat((sep_token, candidate_token_ids, sep_token, candidate_token_ids), dim=-1)
tokenizer.decode(candidate_context_ids[0])

In [None]:
# Feed the candidate context after the cached prefix and keep the hidden states of the second
# copy of the candidates (positions CAN_NUM+2 .. 2*CAN_NUM+1). Unwrap DataParallel if present.
with torch.no_grad():
    rerank_outputs = getattr(model, "module", model).transformer(
        candidate_context_ids,
        past_key_values=past_key_values,
    )

rerank_hidden_states = rerank_outputs[0][:, 2+CAN_NUM:2+CAN_NUM*2]
rerank_hidden_states.shape

In [None]:
# Raw re-ranking score per candidate, shape (2, CAN_NUM, 1); unwrap DataParallel if present.
getattr(model, "module", model).rerank_linear_head(rerank_hidden_states)

In [None]:
# Re-ranking distribution over the CAN_NUM candidates (softmax along dim 1); unwrap DataParallel if present.
nn.functional.softmax(getattr(model, "module", model).rerank_linear_head(rerank_hidden_states), dim=1)