In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import yaml
import numpy as np
from argparse import ArgumentParser

from transformers import BertTokenizer
from transformers import BertForSequenceClassification

from transformer.Models import Transformer
from transformer.Optim import ScheduledOptim
from utils import cal_loss, cal_performance, log_performances_with_cls

In [2]:
# parse argument
parser = ArgumentParser()
parser.add_argument("--lambda", dest="lambda_", type=float, default=10)

args = parser.parse_args()
lambda_ = args.lambda_
print("lambda:", lambda_)

# lambda_ = 10

In [3]:
# fix seed
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed = 0
same_seeds(seed)

In [4]:
##### Read Arguments from Config File #####

# read from command line

config_path = '../configs/dpng_transformer_bert_tokenizer_with_classifier.yaml'

preprocessed = False

with open(config_path) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
    print(config)

save_model_path = config['save_model_path']
log_file = config['log_file']
use_dataset = config['dataset']

num_epochs = config['num_epochs']
batch_size = config['batch_size']

d_model = config['d_model']
d_inner_hid = config['d_inner_hid']
d_k = config['d_k']
d_v = config['d_v']

n_head = config['n_head']
n_layers = config['n_layers']
n_warmup_steps = config['n_warmup_steps']

dropout = config['dropout']
embs_share_weight = config['embs_share_weight']
proj_share_weight = config['proj_share_weight']
label_smoothing = config['label_smoothing']

train_size = config['train_size']
val_size = config['val_size']

try:
    is_bow = config['is_bow']

    if is_bow:
        bow_strategy = config['bow_strategy']
        topk = config['topk']
        if bow_strategy != 'simple_sum':
            indiv_topk = config['indiv_topk']
        else:
            # not used but use default value for simplicity
            indiv_topk = 50
        
        only_bow = config['only_bow']
        replace_predict = config['replace_predict']
        append_bow = config['append_bow']
        
except KeyError:
    is_bow = False
    
try:
    use_wordnet = config['use_wordnet']
    indiv_k = config['indiv_k']
    replace_origin = config['replace_origin']
except KeyError:
    use_wordnet = False

# todo: add to params
lr = float(config['lr'])
# lr = 5e-4
# ###################


{'save_model_path': '../models/DNPG_base_transformer_bert_tokenizer_with_classifier.pth', 'log_file': '../logs/DNPG_base_transformer_bert_tokenizer_with_classifier.txt', 'test_output_file': '../outputs/test_DNPG_transformer_bert_tokenizer_with_classifier_out.txt', 'val_output_file': '../outputs/val_DNPG_transformer_bert_tokenizer_with_classifier_out.txt', 'dataset': 'quora_bert_dataset', 'num_epochs': 50, 'batch_size': 50, 'd_model': 450, 'd_inner_hid': 512, 'd_k': 50, 'd_v': 50, 'n_head': 9, 'n_layers': 3, 'n_warmup_steps': 30000, 'dropout': 0.1, 'embs_share_weight': True, 'proj_share_weight': True, 'label_smoothing': False, 'train_size': 100000, 'val_size': 4000, 'test_size': 20000, 'is_bow': False, 'lr': '1e-3'}


In [3]:
# debug
# batch_size = 50

In [5]:
# load dataset
# preprocessed = False
if preprocessed:
    from datasets.quora_preprocessed_dataset import QuoraPreprocessedDataset as Dataset
else:
    if use_dataset == 'quora_dataset':
        from datasets.quora_dataset import QuoraDataset as Dataset
    elif use_dataset == 'quora_bert_dataset':
        from datasets.quora_bert_dataset import QuoraBertDataset as Dataset
    elif use_dataset == 'quora_bert_mask_predict_dataset':
        from datasets.quora_bert_mask_predict_dataset import QuoraBertMaskPredictDataset as Dataset
    elif use_dataset == 'quora_word_mask_prediction_dataset':
        from datasets.quora_word_mask_prediction_dataset import QuoraWordMaskPredictDataset as Dataset
    elif use_dataset == 'quora_wordnet_dataset':
        from datasets.quora_wordnet_dataset import QuoraWordnetDataset as Dataset
    else:
        raise NotImplementedError("Dataset is not defined or not implemented")

In [6]:
def create_mini_batch(samples):
    seq1_tensors = [s[0] for s in samples]
    seq2_tensors = [s[1] for s in samples]

    # zero pad
    seq1_tensors = pad_sequence(seq1_tensors,
                                  batch_first=True)

    seq2_tensors = pad_sequence(seq2_tensors,
                                  batch_first=True)    

    return seq1_tensors, seq2_tensors


In [7]:
if preprocessed:
    model_name = config_path.split('/')[-1][:-5]
    preprocessed_file = '../data/preprocess_all_{}.npy'.format(model_name)
    dataset = Dataset("train", train_size, val_size, preprocessed_file=preprocessed_file)
    val_dataset = Dataset("val", train_size, val_size, preprocessed_file=preprocessed_file)    
elif is_bow:
    dataset = Dataset(
        "train", train_size, val_size, bow_strategy=bow_strategy, topk=topk, indiv_topk=indiv_topk, 
        only_bow=only_bow, use_origin=only_bow, replace_predict=replace_predict, append_bow=append_bow
    )
    # try not to replace predict when validation?
    val_dataset = Dataset(
        "val", train_size, val_size, bow_strategy=bow_strategy, topk=topk, indiv_topk=indiv_topk, 
        only_bow=only_bow, use_origin=only_bow, replace_predict=replace_predict, append_bow=append_bow
    )
elif use_wordnet:
    dataset = Dataset("train", train_size, val_size, indiv_k=indiv_k, replace_origin=replace_origin)
    val_dataset = Dataset("val", train_size, val_size)
else:
    dataset = Dataset("train", train_size, val_size)
    val_dataset = Dataset("val", train_size, val_size)

data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=create_mini_batch, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=create_mini_batch, shuffle=False)

In [8]:
transformer = Transformer(
    dataset.n_words,
    dataset.n_words,
    src_pad_idx=dataset.PAD_token_id,
    trg_pad_idx=dataset.PAD_token_id,
    trg_emb_prj_weight_sharing=proj_share_weight,
    emb_src_trg_weight_sharing=embs_share_weight,
    d_k=d_k,
    d_v=d_v,
    d_model=d_model,
    d_word_vec=d_model,
    d_inner=d_inner_hid,
    n_layers=n_layers,
    n_head=n_head,
    dropout=dropout,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = transformer.to(device)
print(device)

cuda


In [9]:
optimizer = ScheduledOptim(
    optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09, lr=lr),
    2.0, d_model, n_warmup_steps)
# optimizer = optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09, lr=lr)


In [10]:
criterion = torch.nn.CrossEntropyLoss()

In [11]:
pretrained_model_name = "bert-base-cased"
num_labels = 2
classifier = BertForSequenceClassification.from_pretrained(
    pretrained_model_name, num_labels=num_labels)
classifier = classifier.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [12]:
def get_classify_batch(pred, seq2):
    trg_seq = seq2.to(device)

    pred_seq = pred.max(1)[1].to(device)
    pred_seq = pred_seq.reshape(batch_size, -1)

    concat_seqs = []

    for tseq, pseq in zip(trg_seq, pred_seq):
        nopad_trg = tseq[tseq.nonzero(as_tuple=True)]
        concat_seq = torch.cat((nopad_trg, pseq))
        concat_seqs.append(concat_seq)

    concat_seqs = pad_sequence(concat_seqs,batch_first=True).long()
    masks_tensors = torch.zeros(concat_seqs.shape,
                                dtype=torch.long).to(device)
    # let bert attends only not padding ones
    masks_tensors = masks_tensors.masked_fill(
        concat_seqs != 0, 1)
    
    labels = torch.ones(len(trg_seq)).long().to(device)
    
    return concat_seqs, masks_tensors, labels

def cal_classify_loss(pred, seq2):
    concat_seqs, masks_tensors, labels = get_classify_batch(pred, seq2)
    outputs = classifier(concat_seqs)
    batch_loss = criterion(outputs[0], labels)
    
    return batch_loss
        
    

In [13]:
# train epoch
def train_epoch(model, data_loader, optimizer, device, smoothing=False):
    model.train()
    total_seq_loss, n_word_total, n_word_correct, total_cls_loss, total_batch_loss = 0, 0, 0, 0, 0
    trange = tqdm(data_loader)
    
    # debug
    count = 0
    
    for seq1, seq2 in trange:
        src_seq = seq1.to(device)
        trg_seq = seq2[:, :-1].to(device)
        gold = seq2[:, 1:].contiguous().view(-1).to(device)

        optimizer.zero_grad()
        pred = model(src_seq, trg_seq)

        seq2seq_loss, n_correct, n_word = cal_performance(
            pred, gold, dataset.PAD_token_id, smoothing) 
        
        classify_loss = cal_classify_loss(pred, seq2)
        
        loss = seq2seq_loss + lambda_ * classify_loss
        
        loss.backward()
#         optimizer.step()
        optimizer.step_and_update_lr()

        n_word_total += n_word
        n_word_correct += n_correct
        total_batch_loss += loss.item()
        total_seq_loss += seq2seq_loss.item()
        total_cls_loss += classify_loss.item()
        
        trange.set_postfix({
            'classify_loss': classify_loss.item(),
            'seq2seq_loss': seq2seq_loss.item()
        })
        
        # debug
        print('classify_loss', classify_loss.item(), 'seq2seq_loss', seq2seq_loss.item())
        count += 1
        if count == 10:
            break

    loss_per_word = total_seq_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    avg_cls_loss = total_cls_loss/len(data_loader)
    avg_batch_loss = total_batch_loss/len(data_loader)
    
    return loss_per_word, accuracy, avg_cls_loss, avg_batch_loss



In [14]:
def eval_epoch(model, val_data_loader, device):
    ''' Epoch operation in evaluation phase '''

    model.eval()
    total_seq_loss, n_word_total, n_word_correct, total_cls_loss, total_batch_loss = 0, 0, 0, 0, 0
    
    trange = tqdm(val_data_loader)

    with torch.no_grad():
        for seq1, seq2 in trange:

            src_seq = seq1.to(device)
            trg_seq = seq2[:, :-1].to(device)
            gold = seq2[:, 1:].contiguous().view(-1).to(device)

            pred = model(src_seq, trg_seq)

            seq2seq_loss, n_correct, n_word = cal_performance(
            pred, gold, dataset.PAD_token_id, smoothing=False) 
        
            classify_loss = cal_classify_loss(pred, seq2)

            loss = seq2seq_loss + lambda_ * classify_loss

            n_word_total += n_word
            n_word_correct += n_correct
            total_batch_loss += loss.item()
            total_seq_loss += seq2seq_loss.item()
            total_cls_loss += classify_loss.item()

            trange.set_postfix({
                'classify_loss': classify_loss.item(),
                'seq2seq_loss': seq2seq_loss.item()
            })

    loss_per_word = total_seq_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    avg_cls_loss = total_cls_loss/len(val_data_loader)    
    avg_batch_loss = total_batch_loss/len(val_data_loader)
    
    return loss_per_word, accuracy, avg_cls_loss, avg_batch_loss



In [15]:
# debug
log_file = '../logs/tmp.txt'
f = open(log_file, 'w')
best_loss = 99999

f.write("Config: {}\n".format(config))

# debug
for epoch in range(1):
# for epoch in range(num_epochs):
    print("Epoch {} / {}".format(epoch + 1, num_epochs))
    start = time.time()
    train_loss_per_word, train_accu, train_cls_loss, avg_train_loss = train_epoch(
        model, data_loader, optimizer, device, smoothing=label_smoothing
    )
    
    log_performances_with_cls(
        'Training', train_loss_per_word, train_accu, train_cls_loss, avg_train_loss, start, f
    )
    
    start = time.time()
    valid_loss_per_word, valid_accu, valid_cls_loss, avg_valid_loss = eval_epoch(model, val_data_loader, device)

    log_performances_with_cls(
        'Validation', valid_loss_per_word, valid_accu, valid_cls_loss, avg_valid_loss, start, f
    )    
    
    if avg_valid_loss < best_loss:
        # save model
        torch.save(model.state_dict(), save_model_path)
        best_loss = avg_valid_loss
        print("model saved in Epoch {}".format(epoch + 1))
        f.write("model saved in Epoch {}\n".format(epoch + 1))
        f.flush()

f.close()
    

  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch 1 / 50


100%|██████████| 2000/2000 [10:23<00:00,  3.21it/s, classify_loss=0.756, seq2seq_loss=6.49e+3]
  1%|▏         | 1/80 [00:00<00:11,  6.71it/s, classify_loss=0.739, seq2seq_loss=6.76e+3]

  - (Training)   ppl:  20301.76403, accuracy: 7.528 %, avg_cls_loss:  0.75905, avg_loss:  6607.91607, elapse: 623.908 sec



100%|██████████| 80/80 [00:09<00:00,  8.56it/s, classify_loss=0.746, seq2seq_loss=6.11e+3]
  0%|          | 0/2000 [00:00<?, ?it/s]

  - (Validation) ppl:  9952.02011, accuracy: 7.821 %, avg_cls_loss:  0.74849, avg_loss:  6098.32438, elapse: 9.350 sec

Epoch 2 / 50


 14%|█▍        | 289/2000 [01:32<09:05,  3.14it/s, classify_loss=0.735, seq2seq_loss=5.72e+3]


KeyboardInterrupt: 