In [1]:
import os
# Expose only GPU index 1 to this process. This is set before torch is imported
# further down so the restriction is in place when CUDA is first initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import torchfly
# Fix the random seed for the run.
# NOTE(review): presumably this seeds Python, NumPy and torch RNGs — confirm in torchfly.
torchfly.set_random_seed(1)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import regex as re
import random
import itertools
import tqdm
import time
import json

from torch.utils.tensorboard import SummaryWriter
from apex import amp
from allennlp.training.checkpointer import Checkpointer
from pytorch_transformers import AdamW, WarmupLinearSchedule, GPT2Tokenizer

from torchfly.criterions import SequenceFocalLoss, SequenceCrossEntropyLoss
from gpt_model import GPT2SimpleLM
from text_utils import recoverText, normalize

In [4]:
# Build the GPT-2 BPE tokenizer used by both models.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# NOTE(review): this assigns the literal string "None", not the None object — confirm intended.
tokenizer.sep_token = "None"
# Add special tokens in the same order as RoBERTa; they extend the base vocabulary,
# which is why the model configs below add len(tokenizer.added_tokens_encoder).
tokenizer.add_tokens(["<s>", "<pad>", "</s>", "<unk>", "<mask>"])

5

In [5]:
class GPT2SmallConfig:
    """Hyper-parameters handed to ``GPT2SimpleLM`` for the 12-layer model."""

    # --- vocabulary -------------------------------------------------------
    # tokens registered on the tokenizer on top of the 50257 base entries
    n_special = len(tokenizer.added_tokens_encoder)
    vocab_size = 50257 + n_special

    # --- context window ---------------------------------------------------
    n_positions = 1024
    n_ctx = 1024

    # --- network size -----------------------------------------------------
    n_embd = 768
    n_layer = 12
    n_head = 12

    # --- regularisation / initialisation ----------------------------------
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02

    # --- memory / speed trade-off -----------------------------------------
    gradient_checkpointing = False
    
class GPT2MediumConfig:
    """Hyper-parameters handed to ``GPT2SimpleLM`` for the 24-layer model.

    Mirrors ``GPT2SmallConfig`` except for the network size and gradient
    checkpointing being switched on.
    """
    # Base GPT-2 vocabulary (50257 entries) plus the tokens added to the tokenizer.
    # The base size was previously omitted, which made the vocabulary only as
    # large as the handful of added tokens.
    vocab_size = 50257 + len(tokenizer.added_tokens_encoder)
    n_special = len(tokenizer.added_tokens_encoder)
    n_positions = 1024
    n_ctx = 1024
    n_embd = 1024
    n_layer = 24
    n_head = 16
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02
    gradient_checkpointing = True

In [6]:
# Two independent copies of the same architecture: model_A is used for the user
# side and model_B for the database/system side in the training code below.
# Both start from the same checkpoint file.
pretrained_weights_path = "gpt2_small.pth"

model_A = GPT2SimpleLM(GPT2SmallConfig)
model_B = GPT2SimpleLM(GPT2SmallConfig)

for _model in (model_A, model_B):
    # the file is read once per model, exactly as before
    _model.load_state_dict(torch.load(pretrained_weights_path))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [7]:
def align_keep_indices(batch_keep_indices):
    """Re-express each turn's surviving batch indices relative to the previous turn.

    ``batch_keep_indices[t]`` lists the original batch positions of the dialogs
    that still have a turn ``t``. Callers use the result to index tensors that
    are laid out in the order of the previous turn's batch, so every entry after
    the first is translated into a row number within ``batch_keep_indices[t - 1]``.

    Args:
        batch_keep_indices: list (one entry per turn) of lists of batch indices.

    Returns:
        A list of the same length. Element 0 is ``batch_keep_indices[0]``
        itself; element ``t > 0`` holds, for each index kept at turn ``t``, its
        position inside ``batch_keep_indices[t - 1]``. An empty input yields ``[]``.

    Raises:
        ValueError: if a turn keeps an index that the previous turn did not.
    """
    if not batch_keep_indices:
        return []

    # The reference layout for turn 1 is turn 0 (this used to start from turn 1,
    # which produced a wrong first row and failed on single-turn batches).
    prev = batch_keep_indices[0]
    new_batch_keep_indices = [prev]

    for curr in batch_keep_indices[1:]:
        new_batch_keep_indices.append([prev.index(idx) for idx in curr])
        prev = curr

    return new_batch_keep_indices


class MultiWOZDataset(Dataset):
    """Dataset that tokenizes one dialog at a time, turn by turn.

    Every item is a list with one tuple per turn::

        (user_tokens, user_pos, system_tokens, system_pos,
         database_tokens, database_pos)

    All six entries are 1-D integer tensors. Position ids run continuously over
    the whole dialog, and within a turn they are assigned in the order
    user -> database -> system.
    """

    def __init__(self, data, tokenizer):
        """Store the raw dialogs and pre-encode the fixed prefixes.

        Args:
            data: indexable collection of dialogs; each dialog is read through
                its ``'log'`` key, which holds the list of turns.
            tokenizer: object exposing ``encode(str) -> list[int]``.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.bos = tokenizer.encode("<s>")
        self.user_bos = tokenizer.encode("A:")
        self.system_bos = tokenizer.encode("B:")

        # Hard-coded end-of-utterance token ids appended to every user/system
        # utterance. NOTE(review): presumably the GPT-2 ids for newline tokens — confirm.
        self.eos = [628, 198]

    def __len__(self):
        """Number of dialogs."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize dialog ``index`` and return its list of per-turn tuples."""
        full_dialog = self.data[index]['log']
        # Use the tokenizer given to the constructor rather than a notebook global,
        # so the dataset works with whatever tokenizer it was built with.
        encode = self.tokenizer.encode

        full_dialog_tokens = []
        cur_pos = 0

        for turn_dialog in full_dialog:
            domain = turn_dialog['turn_domain'].strip("[]")

            # user
            user = recoverText(turn_dialog['user_delex'])
            user_tokens = self.user_bos + encode(user) + self.eos
            user_pos = torch.arange(cur_pos, cur_pos + len(user_tokens))
            cur_pos += len(user_tokens)

            # Database
            # NOTE(review): eval() executes whatever the data file contains; it is
            # kept to preserve behavior, but ast.literal_eval would be the safe
            # choice if 'pointer' is always a literal — confirm the data format.
            pointer_tail = eval(turn_dialog['pointer'])[-2:]
            if pointer_tail == (1, 0):
                booked = "book"
            elif pointer_tail == (0, 1):
                booked = "fail"
            else:
                booked = "none"

            # number of matches, capped at 4; an empty string counts as 0
            if len(turn_dialog['match']) > 0:
                num_match = min(int(turn_dialog['match']), 4)
            else:
                num_match = 0

            database = str(num_match) + ";" + booked + ";" + domain + ";"
            database_tokens = encode(database)
            database_pos = torch.arange(cur_pos, cur_pos + len(database_tokens))
            cur_pos += len(database_tokens)

            # System
            system = recoverText(process_text(turn_dialog['resp'], domain))
            system_tokens = self.system_bos + encode(system) + self.eos
            system_pos = torch.arange(cur_pos, cur_pos + len(system_tokens))
            cur_pos += len(system_tokens)

            full_dialog_tokens.append((torch.LongTensor(user_tokens),
                                       user_pos,
                                       torch.LongTensor(system_tokens),
                                       system_pos,
                                       torch.LongTensor(database_tokens),
                                       database_pos))

        return full_dialog_tokens
        

class Collate_Function:
    """Callable collator that batches tokenized dialogs turn by turn.

    For every turn index it gathers the dialogs that still have that turn, pads
    their token / position tensors to a common length and derives padding masks.
    """

    # (index inside a per-turn tuple, whether the field holds token ids)
    _FIELDS = ((0, True), (1, False), (2, True), (3, False), (4, True), (5, False))

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad = self.tokenizer.encode("<pad>")[0]

    def _pad_field(self, turns, field, fill):
        """Right-pad entry ``field`` of every turn tuple into one batch-first tensor."""
        return pad_sequence([turn[field] for turn in turns],
                            batch_first=True,
                            padding_value=fill)

    def __call__(self, unpacked_data):
        """Collate a list of dialogs (each a list of per-turn tuples).

        Returns:
            ``(batch_dialogs, batch_keep_indices)``. ``batch_dialogs[t]`` is the tuple
            ``(user_tokens, user_pos, user_mask, system_tokens, system_pos,
            system_mask, database_tokens, database_pos, database_mask)`` for turn
            ``t``; ``batch_keep_indices[t]`` lists the batch positions of the
            dialogs that contribute a row to that turn.
        """
        longest_dialog = max([len(dialog) for dialog in unpacked_data])

        batch_dialogs, batch_keep_indices = [], []

        for turn_num in range(longest_dialog):
            # dialogs that have not ended before this turn
            alive = [row for row, dialog in enumerate(unpacked_data)
                     if turn_num < len(dialog)]
            turns = [unpacked_data[row][turn_num] for row in alive]

            # token fields are padded with the pad id, position fields with 0
            (user_tokens, user_pos,
             system_tokens, system_pos,
             database_tokens, database_pos) = [
                self._pad_field(turns, field, self.pad if is_tokens else 0)
                for field, is_tokens in self._FIELDS
            ]

            # 1 where a real token sits, 0 on padding
            user_mask, system_mask, database_mask = [
                (tokens != self.pad).byte()
                for tokens in (user_tokens, system_tokens, database_tokens)
            ]

            batch_dialogs.append((user_tokens, user_pos, user_mask,
                                  system_tokens, system_pos, system_mask,
                                  database_tokens, database_pos, database_mask))
            batch_keep_indices.append(alive)

        return batch_dialogs, batch_keep_indices

In [8]:
def calculate_loss(logits, target, mask):
    """Next-token loss for one segment, computed with the global ``criterion``.

    The logits are truncated by their last step and the targets / mask by their
    first step, so the prediction at step ``t`` is scored against token ``t + 1``.
    The mask is cast to float before being handed to the criterion.
    """
    shifted_logits = logits[:, :-1].contiguous()
    shifted_target = target[:, 1:].contiguous()
    shifted_mask = mask[:, 1:].contiguous().float()
    return criterion(shifted_logits,
                     shifted_target,
                     shifted_mask,
                     label_smoothing=0.01,
                     reduce=True)

def filter_past(past, keep_indices):
    """Select ``keep_indices`` along dimension 1 of every tensor in ``past``.

    Returns a new list; the input list itself is left untouched.
    """
    kept = []
    for cached in past:
        kept.append(cached[:, keep_indices])
    return kept

def replace_punc(x):
    """Strip angle brackets and space out punctuation in ``x``.

    The substitutions are applied one after another, in the listed order.
    NOTE(review): commas are rewritten to " ." and the "?" rule is applied twice
    (so a question mark ends up preceded by two spaces) — both look unintended,
    but they are reproduced as-is because downstream scoring may rely on them.
    """
    substitutions = (
        ("<", ""),
        (">", ""),
        (".", " ."),
        (",", " ."),
        ("?", " ?"),
        ("?", " ?"),
    )
    for old, new in substitutions:
        x = x.replace(old, new)
    return x

In [9]:
def process_text(text, domain):
    """Normalise delexicalised slot placeholders in a response.

    First a fixed set of placeholders is renamed to shared names, then the
    remaining generic ``[value_*]`` placeholders for postcode, reference,
    address, phone, name and id are rewritten to ``[<domain>_*]``.
    """
    shared_renames = (
        ("[value_choice]", "[value_count]"),
        ("[value_people]", "[value_count]"),
        ("[value_starts]", "[value_count]"),
        ("[value_car]", '[taxi_type]'),
        ("[value_leave]", "[value_time]"),
        ("[value_arrive]", "[value_time]"),
        ("[value_price]", "[value_pricerange]"),
    )
    for placeholder, replacement in shared_renames:
        text = text.replace(placeholder, replacement)

    # these slots become specific to the turn's domain
    for slot in ("postcode", "reference", "address", "phone", "name", "id"):
        text = text.replace(f'[value_{slot}]', f'[{domain}_{slot}]')

    return text

In [10]:
def _load_json(path):
    """Read and parse one JSON file."""
    with open(path) as handle:
        return json.load(handle)


train_data = _load_json("../yichi_data/clean_train_data.json")
val_data = _load_json("../yichi_data/val_data.json")
test_data = _load_json("../yichi_data/test_data.json")

# Shuffle the training dialogs once up front; all of them are used.
indices = np.arange(len(train_data))
np.random.shuffle(indices)
train_data = [train_data[idx] for idx in indices]

In [11]:
# batch sizes: one dialog per step for training, four for evaluation
train_batch_size = 1
eval_batch_size = 4

# one dataset per split, all sharing the same tokenizer
train_dataset, val_dataset, test_dataset = [
    MultiWOZDataset(split, tokenizer)
    for split in (train_data, val_data, test_data)
]

collate_func = Collate_Function(tokenizer)

# only the training loader shuffles
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=train_batch_size,
                              shuffle=True,
                              collate_fn=collate_func)

val_dataloader, test_dataloader = [
    DataLoader(dataset=eval_dataset,
               batch_size=eval_batch_size,
               shuffle=False,
               collate_fn=collate_func)
    for eval_dataset in (val_dataset, test_dataset)
]

In [12]:
# Loss object called by calculate_loss().
# NOTE(review): with gamma=0.0 and beta=0.0 this presumably reduces to plain
# (label-smoothed) cross-entropy — confirm against torchfly's SequenceFocalLoss.
criterion = SequenceFocalLoss(gamma=0.0, beta=0.0)

In [13]:
# Move both models to the (single visible) CUDA device.
device = torch.device("cuda")
model_A, model_B = model_A.to(device), model_B.to(device)

## Training

In [14]:
# Checkpoints are written under ./Checkpoint. Judging by the argument names,
# at most the 10 most recent are kept, plus one permanent copy every 2 hours.
checkpointer = Checkpointer(serialization_dir="Checkpoint", 
                            keep_serialized_model_every_num_seconds=3600*2, 
                            num_serialized_models_to_keep=10)

In [15]:
# ----- schedule length -----
num_epochs = 20
num_gradients_accumulation = 4
num_train_optimization_steps = (
    len(train_dataset) * num_epochs // train_batch_size // num_gradients_accumulation
)

# ----- parameter groups -----
# a single optimizer drives the parameters of both models
param_optimizer = [*model_A.named_parameters(), *model_B.named_parameters()]
# parameters whose name contains any of these substrings get no weight decay
no_decay = ['ln', 'bias', 'LayerNorm']


def _is_no_decay(param_name):
    """True when the parameter name matches one of the ``no_decay`` substrings."""
    return any(nd in param_name for nd in no_decay)


optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not _is_no_decay(n)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if _is_no_decay(n)],
     'weight_decay': 0.0},
]

# ----- optimizer and learning-rate schedule -----
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4, correct_bias=False)

scheduler = WarmupLinearSchedule(optimizer,
                                 warmup_steps=1000,
                                 t_total=num_train_optimization_steps)

In [16]:
# [model_A, model_B], optimizer = amp.initialize([model_A, model_B], optimizer, opt_level="O1")

In [17]:
# weight applied to the user-side (model_A) loss when the losses are summed
user_weight = 1.0


def train_one_iter(batch_dialogs, batch_keep_indices, update_count, fp16=False):
    """Run the forward and backward pass for one batch of dialogs.

    Every turn is processed in three segments that share one growing attention
    mask and one ``past`` cache: the user utterance goes through ``model_A``,
    then the database tokens and the system response go through ``model_B``.
    The three segment losses are summed per turn, averaged over the number of
    turns and divided by ``num_gradients_accumulation`` before ``backward()``.
    No optimizer step is taken here.

    Args:
        batch_dialogs: per-turn tuples as produced by ``Collate_Function``.
        batch_keep_indices: per-turn lists of surviving batch indices.
        update_count: not used inside this function.
        fp16: when true, backpropagate through apex ``amp.scale_loss``.

    Returns:
        ``(record_loss, perplexity)`` where ``record_loss`` is the per-turn mean
        loss (accumulation scaling undone) and ``perplexity`` is its exponential.
    """
    aligned_batch_keep_indices = align_keep_indices(batch_keep_indices)

    # empty mask that each segment's mask is concatenated onto
    mask = torch.ByteTensor([]).to(device)
    past = None
    total_loss = 0
    prev_batch_size = batch_dialogs[0][0].shape[0]

    for dialogs, keep_indices in zip(batch_dialogs, aligned_batch_keep_indices):
        # data send to gpu
        (user_tokens, user_pos, user_mask,
         system_tokens, system_pos, system_mask,
         database_tokens, database_pos, database_mask) = [item.to(device) for item in dialogs]

        # some dialogs ended: drop their rows from the cache and the mask
        if len(keep_indices) != prev_batch_size:
            past = filter_past(past, keep_indices)
            mask = mask[keep_indices, :]

        # User Utterance
        mask = torch.cat([mask, user_mask], dim=-1)
        logits, past = model_A(user_tokens, position_ids=user_pos, mask=mask, past=past)
        A_loss = calculate_loss(logits, user_tokens, user_mask)

        # Database Tokens
        mask = torch.cat([mask, database_mask], dim=-1)
        logits, past = model_B(database_tokens, position_ids=database_pos, mask=mask, past=past)
        database_loss = calculate_loss(logits, database_tokens, database_mask)

        # System Response
        mask = torch.cat([mask, system_mask], dim=-1)
        logits, past = model_B(system_tokens, position_ids=system_pos, mask=mask, past=past)
        B_loss = calculate_loss(logits, system_tokens, system_mask)

        total_loss = total_loss + user_weight * A_loss + B_loss + database_loss
        prev_batch_size = user_tokens.shape[0]

    # mean over turns, then scale for gradient accumulation
    total_loss /= len(batch_keep_indices)
    total_loss /= num_gradients_accumulation

    if not fp16:
        total_loss.backward()
    else:
        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
            scaled_loss.backward()

    # report the loss without the accumulation scaling
    record_loss = total_loss.item() * num_gradients_accumulation
    perplexity = np.exp(record_loss)

    return record_loss, perplexity

In [18]:
def calculate_length(batch_dialogs):
    """Count the non-padding tokens of a collated batch, summed over all turns.

    Args:
        batch_dialogs: list of per-turn tuples whose entries 2, 5 and 8 are the
            user, system and database padding masks (as built by
            ``Collate_Function``).

    Returns:
        The sum of the three masks over the last dimension, accumulated over
        turns: one value per batch row. For an empty list the integer ``0`` is
        returned.
    """
    total_sum = 0
    # Iterate over the argument itself; the loop bound used to come from a
    # notebook global (`batch_keep_indices`) that only exists inside other loops.
    for turn in batch_dialogs:
        user_mask, system_mask, database_mask = turn[2], turn[5], turn[8]
        total_sum += user_mask.sum(-1) + system_mask.sum(-1) + database_mask.sum(-1)
    return total_sum

In [19]:
update_count = 0
progress_bar = tqdm.tqdm_notebook
start = time.time()

# Dialogs whose total number of non-padding tokens exceeds this are skipped.
max_dialog_tokens = 900

for ep in range(num_epochs):

    # ---- Training ----
    pbar = progress_bar(train_dataloader)
    model_A.train()
    model_B.train()

    for batch_dialogs, batch_keep_indices in pbar:

        # .item() relies on train_batch_size == 1 (one length per batch row)
        if calculate_length(batch_dialogs).item() > max_dialog_tokens:
            print("exceed limit")
            continue

        # single-turn dialogs are skipped
        if len(batch_keep_indices) < 2:
            continue

        record_loss, perplexity = train_one_iter(batch_dialogs, batch_keep_indices, update_count, fp16=False)

        update_count += 1

        # Step once every `num_gradients_accumulation` micro-batches. The count is
        # incremented before this check, so comparing against 0 makes the first
        # update use a full group as well (it used to fire one micro-batch early).
        if update_count % num_gradients_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(model_A.parameters(), 5.0)
            torch.nn.utils.clip_grad_norm_(model_B.parameters(), 5.0)
            # optimizer first, then scheduler: stepping the scheduler first skips
            # the first value of the learning-rate schedule
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # speed measure (dialogs per second since the previous update)
            end = time.time()
            speed = train_batch_size * num_gradients_accumulation / (end - start)
            start = end

            # show progress
            pbar.set_postfix(loss=record_loss, perplexity=perplexity, speed=speed)

    # Evaluation with validate() is currently not run during training.

    # save both models' weights at the end of every epoch
    checkpointer.save_checkpoint(ep, 
                                 [model_A.state_dict(), model_B.state_dict()],
                                 {"None": None},
                                 True
                                 )

HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



HBox(children=(IntProgress(value=0, max=5091), HTML(value='')))

exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit
exceed limit



In [20]:
# Collect the token length of every training batch.
# Iterate the dataloader directly instead of the progress-bar object left over
# from the last training epoch, so this cell does not depend on that leftover state.
res = []

for batch_dialogs, batch_keep_indices in train_dataloader:
    length = calculate_length(batch_dialogs).item()
    res.append(length)

In [21]:
# Turn the collected lengths into an array (rebinds `res` from list to ndarray).
res = np.array(res)

In [22]:
def validate(dataloader, data):
    """Greedy-decode a system response for every turn and score the results.

    Stage 1 walks each batch turn by turn: the user utterance is fed to
    ``model_A``; ``model_B`` is then primed with the first three gold system
    tokens and decoded by argmax for up to 50 steps, a row being dropped from
    decoding once it emits token id 628. Afterwards the gold system utterance
    is fed to ``model_B`` so the cache carried to the next turn holds the
    reference history rather than the generated one.

    Stage 2 pairs the generated responses with the references in ``data`` (in
    order) and computes the metrics.

    NOTE(review): ``clean_sentence``, ``entity_dict``, ``success_f1_metric`` and
    ``bleu_metric`` are neither defined nor imported in the visible notebook.
    NOTE(review): stage 2 indexes ``data[i][turn]["replaced_response"]``, while
    ``MultiWOZDataset`` reads ``data[index]['log']`` — confirm the two formats match.

    Args:
        dataloader: yields ``(batch_dialogs, batch_keep_indices)`` pairs.
        data: raw dialogs aligned with the order of ``dataloader``.

    Returns:
        dict with the keys ``"bleu"`` and ``"sccuess_f1"``.
    """

    model_A.eval()
    model_B.eval()

    # NOTE(review): decoding below uses argmax, so dividing the logits by the
    # temperature does not change which token is picked.
    temperature = 0.5

    all_response = []

    for batch_dialogs, batch_keep_indices in tqdm.tqdm_notebook(dataloader):

        aligned_batch_keep_indices = align_keep_indices(batch_keep_indices)
        past = None
        # one list of decoded turns per dialog in the batch
        generated_responses = [[] for i in range(batch_dialogs[0][0].shape[0])]

        mask = torch.ByteTensor([]).to(device)
        prev_batch_size = batch_dialogs[0][0].shape[0]

        with torch.no_grad():
            for turn_num in range(len(batch_keep_indices)):
                # move this turn's tensors to the GPU
                dialogs = batch_dialogs[turn_num]
                dialogs = [item.to(device) for item in dialogs]

                # NOTE(review): the last three entries (named belief_* here) are
                # never fed to the model in this function, whereas train_one_iter
                # feeds the same entries as database tokens between the user and
                # system segments — confirm this mismatch is intended.
                user_tokens, user_pos, user_mask, \
                    system_tokens, system_pos, system_mask, \
                    belief_tokens, belief_pos, belief_mask = dialogs

                # batch filtering algorithm: drop cache/mask rows of ended dialogs
                keep_indices = aligned_batch_keep_indices[turn_num]

                if len(keep_indices) != prev_batch_size:
                    past = filter_past(past, keep_indices)
                    mask = mask[keep_indices, :]

                # define some initials
                cur_batch_size = user_tokens.shape[0]
                # flags[idx] == 1 while row idx is still being decoded
                flags = np.ones(cur_batch_size)
                generated_tokens = [[] for i in range(cur_batch_size)]

                # feed in user
                mask = torch.cat([mask, user_mask], dim=-1)
                _, past = model_A(user_tokens, position_ids=user_pos, mask=mask, past=past)

                # response generation
                response = []  # unused


                # prime the decoder with the first three gold system tokens
                prev_input = system_tokens[:, :3]
                cur_pos = system_pos[:, :3]
                # decoding works on copies of the references so that `past` / `mask`
                # can later be extended with the gold response
                temp_past = past
                temp_mask = F.pad(mask, pad=(0,3), value=1)

                # feed into B
                logits, temp_past = model_B(prev_input, position_ids=cur_pos, mask=temp_mask, past=temp_past)
                # set current position
                cur_pos = cur_pos[:, -1].unsqueeze(1) + 1

                for i in range(50):
                    logits = logits[:, -1, :] / temperature
                    prev_tokens = torch.argmax(logits, dim=-1)
                    np_prev_tokens = prev_tokens.cpu().numpy()
                    # nucleus sampling
                    # logits = top_filtering(logits, top_k=100, top_p=0.7)
                    # probs = F.softmax(logits, -1)
                    # prev_input = torch.multinomial(probs, num_samples=1)

                    # add to generated tokens list; `count` walks the rows that are
                    # still active, `idx` is the row's position in the full batch
                    count = 0
                    for idx, value in enumerate(flags):
                        if value != 0:
                            generated_tokens[idx].append(np_prev_tokens[count])
                            count += 1

                    # filtering algorithm: rows that just produced 628 are finished
                    if np.any(np_prev_tokens == 628):
                        # set flags 0
                        count = 0
                        for idx, value in enumerate(flags):
                            if value == 1:
                                if np_prev_tokens[count] == 628:
                                    flags[idx] = 0
                                count += 1
                        # compute which one to keep (shadows the turn-level keep_indices)
                        keep_indices = np.argwhere(np_prev_tokens != 628).squeeze(1)
                        # filter
                        prev_tokens = prev_tokens[keep_indices.tolist()]
                        cur_pos = cur_pos[keep_indices.tolist(), :]
                        temp_mask = temp_mask[keep_indices.tolist(), :]
                        temp_past = [item[:, keep_indices.tolist()] for item in temp_past]
                        np_prev_tokens = np_prev_tokens[keep_indices.tolist()]

                    if np.all(flags == 0):
                        break

                    # prepare for the next token        
                    temp_mask = F.pad(temp_mask, pad=(0, 1), value=1)
                    logits, temp_past = model_B(prev_tokens.view(-1, 1), 
                                           position_ids=cur_pos, 
                                           mask=temp_mask, 
                                           past=temp_past)
                    cur_pos = cur_pos + 1

                # real system_tokens feed in
                mask = torch.cat([mask, system_mask], dim=-1)
                _, past = model_B(system_tokens, position_ids=system_pos, mask=mask, past=past)

                # inject into generated_responses_list, keyed by original batch index
                decoded_responses = [tokenizer.decode(item).replace("\n", "") for item in generated_tokens]
                count = 0
                for idx in batch_keep_indices[turn_num]:
                    generated_responses[idx].append(decoded_responses[count])
                    count += 1

            # add to the final responses        
            for item in generated_responses:
                all_response.extend(item)
                
    # Stage 2
    #   prepare for metric eval
    dialog_data = []
    count = 0
    all_results = []  # unused

    for i in range(len(data)):
        raw_dialog = data[i]

        for turn_num in range(len(raw_dialog)):

            replaced_response = clean_sentence(
                replace_punc(raw_dialog[turn_num]["replaced_response"].lower().replace("slot", "SLOT")), entity_dict)

            generated_response = clean_sentence(replace_punc(all_response[count].lower().replace("slot", "SLOT")), entity_dict)

            dialog_data.append({"dial_id": raw_dialog[turn_num]["dial_id"],
                                "turn_num": raw_dialog[turn_num]["turn_num"],
                                "response": replaced_response,
                                "generated_response":generated_response 
                              })
            count += 1
            
    sccuess_f1 = success_f1_metric(dialog_data)
    bleu = bleu_metric(dialog_data)

    return {"bleu": bleu,
            "sccuess_f1": sccuess_f1
            }