In [1]:
from typing import Callable, Dict, Optional, Tuple, Union, Any, collections
import torch
import copy 
import random
import numpy as np

from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from torch.utils.data import Dataset

from transformers.file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_sagemaker_distributed_available,
    is_torch_tpu_available,
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
)

from transformers import GPT2LMHeadModel, GPT2Model

import os
import time
import datetime
import torch
import math
import copy 
import random
from packaging import version
import pandas as pd
import numpy as np
import pickle

from torch import nn
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler

from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel, GPT2Model
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import Trainer, TrainerState, TrainingArguments

from transformers.utils import logging
logger = logging.get_logger(__name__)

In [2]:
batch_size = 100
MAX_LEN = 128
CAN_NUM = 20
num_of_rerank = 30

# some parameters I cooked up that work reasonably well
epochs = 1
learning_rate = 5e-4
warmup_steps = 1e2
epsilon = 1e-8

# Set the seed value all over the place to make this reproducible.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

SAVE_PATH = "/mnt/nfs/work1/llcao/zonghaiyao/LM/"

global debug
debug = {}
debug['check_in_num'] = 0
debug['total_sample_num'] = 0

In [3]:
class wiki2021_GPT2Dataset(Dataset):
    def __init__(self, input_ids):
        self.input_ids = input_ids
    
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx]

In [4]:
class rerankGPT2LMHeadModel_labelInCandidate(GPT2LMHeadModel):
    def __init__(self, config, MAX_LEN, CAN_NUM, num_of_rerank):
        super().__init__(config)
        self.MAX_LEN = MAX_LEN
        self.CAN_NUM = CAN_NUM
        self.num_of_rerank = num_of_rerank
        self.VOCAB_SIZE = config.vocab_size
        
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
#         self.dropout = nn.Dropout(config.hidden_dropout_prob)
#         self.rerank_transformer = GPT2Model(config)
        self.rerank_linear_head = nn.Linear(config.n_embd, 1, bias=False)

        self.init_weights()
    def forward(
        self,
        input_ids=None,
        labels=None,
        is_training=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
        """       
        global debug
        # rerank_places = random.sample(np.arange(1, MAX_LEN-2-CAN_NUM*2).tolist(), k=num_of_rerank) #no duplicate
        rerank_places = random.sample(np.arange(1, self.MAX_LEN).tolist(), k=self.num_of_rerank) #no duplicate
        rerank_places = np.concatenate(([0], np.sort(rerank_places), [self.MAX_LEN])) #add first and last tokens to make segments
        
        past_key_values = None
        hidden_states_in_rerank_place = []
        stage1_logits_in_rerank_place = []
        all_rerank_hidden_states = []
        all_rerank_labels = []
        if not is_training:
            no_rerank_logits = [] #no_rerank_labels is the same with all_rerank_labels, only need to be cal when eval
        check_out_num = 0
        for i in range(self.num_of_rerank+1):
            #normal stage
            segment_input_ids = input_ids[:, rerank_places[i]:rerank_places[i+1]]

            segment_outputs = self.transformer(
                segment_input_ids,
                past_key_values = past_key_values
            )

            segment_hidden = segment_outputs[0]
            past_key_values = segment_outputs[1]

            #rerank stage (just for rerank places)
            if i == self.num_of_rerank:
                break

            #rerank stage
            #get logits in rerank place
            logits_before_rerank = self.lm_head(segment_hidden[:, -1, :])
            #get candidate token ids according to the logits
            candidate_token_logits, candidate_token_ids = torch.topk(logits_before_rerank, self.CAN_NUM)
            rerank_labels = labels[..., rerank_places[i+1]]
#             labels_in_rerank_place.append(rerank_labels)
#             hidden_states_in_rerank_place.append(segment_hidden[:, -1, :])

            
            #check whether or not label in candidates
            check_labels = rerank_labels.tolist()
            check_candidates = candidate_token_ids.tolist()
            
#             debug['logits_before_rerank'] = logits_before_rerank
#             debug['check_labels'] = check_labels
#             debug['check_candidates'] = check_candidates
#             debug['candidate_token_logits'] = candidate_token_logits
#             debug['candidate_token_ids'] = candidate_token_ids
            
            assert len(check_labels)==len(check_candidates)
            
            
            #when training, check whether or not label in candidates, if not we add label into candidates
            #check every data in batch
            rerank_labels_this_place = []
            for j in range(len(check_labels)):
                if check_labels[j] not in check_candidates[j]:
                    check_out_num+=1
                    replace_index = np.random.randint(self.CAN_NUM)
                    candidate_token_ids[j][replace_index] = check_labels[j]
                    candidate_token_logits[j][replace_index] = logits_before_rerank[j][check_labels[j]]
                    rerank_labels_this_place.append(replace_index)          
                else:
                    rerank_labels_this_place.append(check_candidates[j].index(check_labels[j]))
            all_rerank_labels.append(torch.tensor(rerank_labels_this_place, device=input_ids.device))
            stage1_logits_in_rerank_place.append(candidate_token_logits)

            #make context for rerank stage, 50256 is the token_id for </endoftext/>
            sep_token = torch.ones(size = [candidate_token_ids.shape[0], 1], dtype = torch.long, device=input_ids.device) * 50256
            candidate_context_ids = torch.cat([sep_token, candidate_token_ids, sep_token, candidate_token_ids], -1)
       
            #get output from gpt2
            rerank_outputs = self.transformer(candidate_context_ids,
                            past_key_values=past_key_values,
                          )

            #get rerank logits for candidates
            rerank_hidden_states = rerank_outputs[0][:, 2+self.CAN_NUM:2+self.CAN_NUM*2]

            all_rerank_hidden_states.append(rerank_hidden_states)
            

#         print("\n batch info:")
#         print("there are ", check_out_num/(self.num_of_rerank*batch_size), "labels not in candidates")
        
        #-------------------------------------------------------------------------
        # cal loss, loss = normal loss + rerank loss
        loss_fct = CrossEntropyLoss(reduction='none')
        
        debug['all_rerank_labels'] = all_rerank_labels
        debug['hidden_states_in_rerank_place'] = hidden_states_in_rerank_place
        
        # cal rerank loss
        rerank_loss = None
        
        all_rerank_hidden_states = torch.cat(all_rerank_hidden_states, 0)
        all_rerank_logits = self.rerank_linear_head(all_rerank_hidden_states)
        all_rerank_logits = torch.reshape(all_rerank_logits, [-1, self.CAN_NUM])
        all_rerank_labels = torch.cat(all_rerank_labels, 0)
        
        rerank_loss = loss_fct(all_rerank_logits, all_rerank_labels)
        

        # cal normal loss in rerank place (for comparision with rerank results)        
        normal_loss_in_rerank_place = None
        
        stage1_logits_in_rerank_place = torch.cat(stage1_logits_in_rerank_place, 0)        
        normal_loss_in_rerank_place = loss_fct(stage1_logits_in_rerank_place, all_rerank_labels)
        
        
        stage1_logits = stage1_logits_in_rerank_place.tolist()
        stage2_logits = all_rerank_logits.tolist()
        labels = all_rerank_labels.tolist()
        
        def cal_gt_value_rank_in_array(logits_array, ground_true_place):
            
            array = (-1) * np.array(logits_array) #for reverse ranking
            temp = array.argsort()
            ranks = np.empty_like(temp)
            ranks[temp] = np.arange(len(array))
            return ranks[ground_true_place]
        ranks_stage1 = []
        ranks_stage2 = []
        for i in range(all_rerank_labels.shape[0]):
            ranks_stage1.append(cal_gt_value_rank_in_array(stage1_logits[i], labels[i]))
            ranks_stage2.append(cal_gt_value_rank_in_array(stage2_logits[i], labels[i]))
            
        debug['ranks_stage1'] = ranks_stage1
        debug['ranks_stage2'] = ranks_stage2
        
        print("rerank_loss:", rerank_loss[:5])
        print("normal_loss:", normal_loss_in_rerank_place[:5])
        print("ranks_stage1:", ranks_stage1[:5])
        print("ranks_stage2:", ranks_stage2[:5])

#         hidden_states_in_rerank_place = torch.cat(hidden_states_in_rerank_place, 0)
#         lm_logits_in_rerank_place = self.lm_head(hidden_states_in_rerank_place)
#         lm_logits_in_rerank_place = torch.reshape(lm_logits_in_rerank_place, [-1, self.VOCAB_SIZE])
#         labels_in_rerank_place = torch.cat(labels_in_rerank_place, 0)

#         normal_loss_in_rerank_place = loss_fct(lm_logits_in_rerank_place, labels_in_rerank_place)
        
        return {"normal_loss_in_rerank_place": normal_loss_in_rerank_place,
                "rerank_loss": rerank_loss,
                "ranks_stage1": ranks_stage1,
                "ranks_stage2": ranks_stage2}

In [5]:
# I'm not really doing anything with the config buheret
configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)

# Load the GPT tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|endoftext|>') #gpt2-medium

# instantiate the model
model = rerankGPT2LMHeadModel_labelInCandidate.from_pretrained(SAVE_PATH + "results/baseline_wiki2021/exclude_cases_label_not_in_candidates_canNUM20/120000", 
                                                                                    config=configuration,
                                                                                    MAX_LEN = MAX_LEN,
                                                                                    CAN_NUM = CAN_NUM, 
                                                                                   num_of_rerank = num_of_rerank)



# this step is necessary because I've added some tokens (bos_token, etc) to the embeddings
# otherwise the tokenizer and model tensors won't match up
model.resize_token_embeddings(len(tokenizer))

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
model = torch.nn.DataParallel(model) # Encapsulate the model


# Tell pytorch to run this model on the GPU.
device = torch.device("cuda")
model.cuda()

DataParallel(
  (module): rerankGPT2LMHeadModel_labelInCandidate(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, i

In [6]:
# with open(SAVE_PATH + 'data/wiki2021/wiki2021_0to4_train_dataset.pkl', 'rb') as f:
#     train_input_ids = pickle.load(f)
with open(SAVE_PATH + 'data/wiki2021/wiki2021_0to4_validation_dataset.pkl', 'rb') as f:
    validation_input_ids = pickle.load(f)
with open(SAVE_PATH + 'data/wiki2021/wiki2021_0to4_inside_validation_dataset.pkl', 'rb') as f:
    inside_validation_input_ids = pickle.load(f)
    
# train_dataset = wiki2021_GPT2Dataset(train_input_ids)
validation_dataset = wiki2021_GPT2Dataset(validation_input_ids)
inside_validation_dataset = wiki2021_GPT2Dataset(inside_validation_input_ids)

# Create the DataLoaders for our training and validation datasets.
# We'll take training samples in random order. 
# train_dataloader = DataLoader(
#             train_dataset,  # The training samples.
#             sampler = RandomSampler(train_dataset), # Select batches randomly
#             batch_size = batch_size # Trains with this batch size.
#         )

# For validation the order doesn't matter, so we'll just read them sequentially.
validation_dataloader = DataLoader(
            validation_dataset, # The validation samples.
            sampler = SequentialSampler(validation_dataset), # Pull out batches sequentially.
            batch_size = batch_size # Evaluate with this batch size.
        )

# For inside_validation the order doesn't matter, so we'll just read them sequentially.
inside_validation_dataloader = DataLoader(
            inside_validation_dataset, # The validation samples.
            sampler = SequentialSampler(inside_validation_dataset), # Pull out batches sequentially.
            batch_size = batch_size # Evaluate with this batch size.
        )

In [7]:
def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [8]:
model = model.to(device)

# ========================================
#               Validation
# ========================================

print("")
print("Running Validation...")

t0 = time.time()

model.eval()

total_eval_loss = 0
total_eval_normal_loss = 0
total_eval_normal_loss_in_rerank_place = 0
total_eval_rerank_loss = 0


# Evaluate data for one epoch
stage1_gt_ranking = []
stage2_gt_ranking = []
for batch in inside_validation_dataloader:
    with torch.no_grad():        

        outputs = model(  input_ids=batch,         #batch_input_ids
                          labels=batch,            #batch_labels
                          is_training=False,
                       )

        normal_loss = outputs["normal_loss_in_rerank_place"].mean()
        rerank_loss = outputs["rerank_loss"].mean()
        stage1_gt_ranking.extend(outputs["ranks_stage1"])
        stage2_gt_ranking.extend(outputs["ranks_stage2"])

        loss = normal_loss + rerank_loss

    batch_loss = loss.item()
    total_eval_loss += batch_loss        

    batch_normal_loss = normal_loss.item()
    total_eval_normal_loss += batch_normal_loss

    batch_rerank_loss = rerank_loss.item()
    total_eval_rerank_loss += batch_rerank_loss

avg_val_loss = total_eval_loss / len(inside_validation_dataloader)
avg_val_normal_loss = total_eval_normal_loss / len(inside_validation_dataloader)          
avg_val_rerank_loss = total_eval_rerank_loss / len(inside_validation_dataloader)    

validation_time = format_time(time.time() - t0)

print("  Validation Loss: {0:.2f}".format(avg_val_loss))
print("  Average Validation normal_loss: {0:.2f}".format(avg_val_normal_loss))
print("  Average Validation rerank_loss: {0:.2f}".format(avg_val_rerank_loss))
print("  Validation took: {:}".format(validation_time))


Running Validation...
rerank_loss: tensor([1.4274, 4.1998, 1.5362, 4.2876, 4.6577], device='cuda:0')
normal_loss: tensor([1.2354, 6.7931, 1.3189, 7.0454, 5.0614], device='cuda:0')
ranks_stage1: [0, 19, 0, 19, 19]
ranks_stage2: [0, 11, 0, 16, 13]
rerank_loss: tensor([4.4119, 1.7867, 2.2997, 5.6258, 1.0694], device='cuda:0')
normal_loss: tensor([6.4540, 1.6272, 2.3816, 8.2732, 1.0126], device='cuda:0')
ranks_stage1: [19, 1, 1, 19, 0]
ranks_stage2: [10, 0, 2, 19, 0]
rerank_loss: tensor([4.7456, 2.1354, 5.7403, 3.5086, 3.0893], device='cuda:0')
normal_loss: tensor([7.2976, 5.9622, 5.1238, 3.1874, 3.2023], device='cuda:0')
ranks_stage1: [19, 19, 19, 5, 6]
ranks_stage2: [18, 0, 18, 6, 6]
rerank_loss: tensor([3.5934, 4.5537, 4.3872, 2.4179, 1.0724], device='cuda:0')
normal_loss: tensor([5.7373, 5.2250, 6.0206, 2.5663, 0.1460], device='cuda:0')
ranks_stage1: [19, 19, 19, 4, 0]
ranks_stage2: [11, 17, 17, 2, 0]
rerank_loss: tensor([0.2390, 2.6041, 2.0155, 6.0171, 1.6813], device='cuda:0')
norma

rerank_loss: tensor([2.4362, 4.4805, 1.5887, 4.4401, 2.3216], device='cuda:0')
normal_loss: tensor([1.9764, 5.3079, 1.4148, 6.4867, 2.9774], device='cuda:0')
ranks_stage1: [2, 19, 0, 19, 4]
ranks_stage2: [2, 15, 0, 15, 2]
rerank_loss: tensor([0.8098, 0.9130, 1.8284, 2.4341, 2.5226], device='cuda:0')
normal_loss: tensor([1.0256, 0.7138, 2.7832, 4.2641, 2.2188], device='cuda:0')
ranks_stage1: [0, 0, 1, 19, 2]
ranks_stage2: [0, 0, 1, 2, 3]
rerank_loss: tensor([4.9060, 1.9892, 3.8972, 3.7233, 1.6379], device='cuda:0')
normal_loss: tensor([7.9648, 2.4831, 6.6280, 3.4957, 1.4975], device='cuda:0')
ranks_stage1: [19, 1, 19, 12, 0]
ranks_stage2: [9, 2, 10, 13, 1]
rerank_loss: tensor([2.6091, 1.0450, 1.9943, 3.0801, 2.8382], device='cuda:0')
normal_loss: tensor([2.4305, 0.8470, 1.8174, 3.0470, 6.0331], device='cuda:0')
ranks_stage1: [2, 0, 1, 5, 19]
ranks_stage2: [2, 0, 1, 9, 5]
rerank_loss: tensor([3.0096, 5.2559, 3.6233, 3.9639, 7.0518], device='cuda:0')
normal_loss: tensor([5.7622, 5.9190, 4

rerank_loss: tensor([0.1852, 2.6631, 7.6209, 5.4579, 4.1735], device='cuda:0')
normal_loss: tensor([1.1859, 1.1177, 7.4793, 4.1737, 3.9832], device='cuda:0')
ranks_stage1: [0, 1, 19, 13, 12]
ranks_stage2: [0, 1, 19, 13, 14]
rerank_loss: tensor([1.6750, 0.0042, 1.7767, 1.5351, 0.6876], device='cuda:0')
normal_loss: tensor([1.9364, 0.0065, 1.7736, 1.3810, 0.4297], device='cuda:0')
ranks_stage1: [1, 0, 1, 0, 0]
ranks_stage2: [1, 0, 1, 1, 0]
rerank_loss: tensor([5.6280, 0.2494, 3.0909, 4.6833, 3.2236], device='cuda:0')
normal_loss: tensor([6.9606, 0.2095, 5.3746, 6.0167, 5.3425], device='cuda:0')
ranks_stage1: [19, 0, 19, 19, 19]
ranks_stage2: [19, 0, 5, 18, 7]
rerank_loss: tensor([2.0642, 1.3654, 2.6942, 5.8503, 5.6522], device='cuda:0')
normal_loss: tensor([3.2007, 2.0254, 2.5499, 6.7583, 7.7974], device='cuda:0')
ranks_stage1: [1, 3, 2, 19, 19]
ranks_stage2: [1, 1, 2, 19, 17]
rerank_loss: tensor([4.1270, 5.2383, 0.6075, 0.3011, 0.3566], device='cuda:0')
normal_loss: tensor([4.8290, 5.78

In [9]:
def mean_reciprocal_rank(gt_ranks, CAN_NUM=5):
    rs = []
    for i in range(len(gt_ranks)):
        tmp = np.zeros(CAN_NUM)
        tmp[gt_ranks[i]] = 1
        rs.append(tmp)
    
    rs = (np.asarray(r).nonzero()[0] for r in rs)
    return np.mean([1. / (r[0] + 1) if r.size else 0. for r in rs]) 


print("normal MRR is ", mean_reciprocal_rank(gt_ranks=stage1_gt_ranking, CAN_NUM=20))
print("rerank MRR is ", mean_reciprocal_rank(gt_ranks=stage2_gt_ranking, CAN_NUM=20))
print()

normal MRR is  0.43380786470336496
rerank MRR is  0.46583580429337307



In [11]:
stage1_gt_in_candidates_cases = []
stage1_gt_not_in_candidates_cases = []
stage2_gt_in_candidates_cases = []
stage2_gt_not_in_candidates_cases = []
for i in range(len(stage1_gt_ranking)):
    if stage1_gt_ranking[i] == 19:
        stage1_gt_not_in_candidates_cases.append(stage1_gt_ranking[i])
        stage2_gt_not_in_candidates_cases.append(stage2_gt_ranking[i])
    else:
        stage1_gt_in_candidates_cases.append(stage1_gt_ranking[i])
        stage2_gt_in_candidates_cases.append(stage2_gt_ranking[i])

In [16]:
print("gt_in_candidates_cases normal MRR is ", mean_reciprocal_rank(gt_ranks=stage1_gt_in_candidates_cases, CAN_NUM=20))
print("gt_in_candidates_cases rerank MRR is ", mean_reciprocal_rank(gt_ranks=stage2_gt_in_candidates_cases, CAN_NUM=20))

gt_in_candidates_cases normal MRR is  0.6504284335209369
gt_in_candidates_cases rerank MRR is  0.6553313709017413


In [17]:
print("gt_not_in_candidates_cases normal MRR is ", mean_reciprocal_rank(gt_ranks=stage1_gt_not_in_candidates_cases, CAN_NUM=20))
print("gt_not_in_candidates_cases rerank MRR is ", mean_reciprocal_rank(gt_ranks=stage2_gt_not_in_candidates_cases, CAN_NUM=20))


gt_not_in_candidates_cases normal MRR is  0.05000000000000002
gt_not_in_candidates_cases rerank MRR is  0.1300879610127941
