In [1]:
import os
import torchfly
# Fix the random seed first; per the original note, this needs to be called
# before torch is imported.
torchfly.set_random_seed(123)

In [2]:
import time
import tqdm
import logging
from apex import amp
from apex.optimizers import FusedAdam
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import torch.nn as nn
import torch
import pandas as pd
import numpy as np
import json
import argparse

from transformers import BertTokenizer
from torchfly.utils import gdrive_download
from torchfly.training.checkpointer import SimpleCheckpointer
from torchfly.modules.transformers import CachedBertEncoder, CachedBertDecoderLM, ChineseBERTBaseConfig
from torchfly.training.optimization import AdamW, WarmupLinearSchedule
from torchfly.modules.losses import SequenceFocalLoss, SequenceCrossEntropyLoss

In [3]:
# Initialise logging through torchfly's helper, then create the logger this
# notebook uses to report validation results.
torchfly.init_logging()
logger = logging.getLogger(__name__)

### Set alpha and beta values

In [4]:
# alpha: handed to SequenceFocalLoss as its `gamma` argument when the training
# loss is created
a = 0.5

# beta: handed to SequenceFocalLoss as its `beta` argument when the training
# loss is created
b = 20.0

In [5]:
# Pretrained "bert-base-chinese" tokenizer. SIAHeadlineDataset reads this
# module-level object to map tokens to vocabulary ids.
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

In [6]:
class SIAHeadlineDataset(Dataset):
    """Dataset that turns headline triples into encoder/decoder id tensors.

    Every element of ``data`` is unpacked as
    ``(summary, origin_title, target_title)``. All three are treated as token
    sequences and are converted to ids with the module-level ``tokenizer``.

    Args:
        data: indexable collection of ``(summary, origin_title, target_title)``.
        max_length: upper bound on the source length, special tokens included.
        title_prob: probability that the original title is appended to the
            summary in the source; otherwise the source is the summary alone.
    """

    def __init__(self, data, max_length=512, title_prob=0.9):
        super().__init__()
        self.dataset = data
        self.max_length = max_length
        self.title_prob = title_prob
        # Hard-coded special-token ids; assumed to be the [CLS]/[SEP] ids of
        # the tokenizer in use -- TODO confirm if the tokenizer changes.
        self.CLS = [101]
        self.SEP = [102]

    def __len__(self):
        """Return the number of examples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build one example.

        Returns:
            ``(source, source_mask, source_type_ids, target, target_mask)``.
            ``source``/``target``/``source_type_ids`` are 1-D LongTensors and
            the masks are 1-D tensors of ones with matching lengths.
        """
        summary, origin_title, target_title = self.dataset[idx]
        origin_title = tokenizer.convert_tokens_to_ids(origin_title)
        target_title = tokenizer.convert_tokens_to_ids(target_title)

        # source and token type
        if np.random.rand() < self.title_prob:
            # Layout: [CLS] summary [SEP] origin_title [SEP] -> 3 special tokens.
            # An over-long title is truncated first so the summary budget can
            # never become negative (a negative slice bound would cut the
            # summary from the wrong end and overflow max_length).
            origin_title = origin_title[:max(self.max_length - 3, 0)]
            budget = max(self.max_length - 3 - len(origin_title), 0)
            summary = tokenizer.convert_tokens_to_ids(summary[:budget])
            source = self.CLS + summary + self.SEP + origin_title + self.SEP
            # Segment 0 covers [CLS] summary [SEP]; segment 1 covers title [SEP].
            source_type_ids = [0] * (len(summary) + 2) + [1] * (len(origin_title) + 1)
        else:
            # Layout: [CLS] summary [SEP] -> 2 special tokens.
            budget = max(self.max_length - 2, 0)
            summary = tokenizer.convert_tokens_to_ids(summary[:budget])
            source = self.CLS + summary + self.SEP
            source_type_ids = [0] * len(source)

        target = self.CLS + target_title + self.SEP

        # turn into tensors
        source = torch.LongTensor(source)
        source_mask = torch.ones(source.shape[0])
        target = torch.LongTensor(target)
        target_mask = torch.ones(target.shape[0])
        source_type_ids = torch.LongTensor(source_type_ids)

        return source, source_mask, source_type_ids, target, target_mask

    @staticmethod
    def mod_collate(batch):
        """Collate examples from ``__getitem__`` into padded batch tensors.

        Every field is right-padded (``pad_sequence`` pads with zeros) to the
        longest sequence in the batch, batch dimension first. The two masks are
        converted to boolean, so padded positions are ``False``.
        """
        source, source_mask, source_type_ids, target, target_mask = zip(*batch)

        # pad sequence
        source = pad_sequence(source, batch_first=True)
        source_mask = pad_sequence(source_mask, batch_first=True).bool()
        source_type_ids = pad_sequence(source_type_ids, batch_first=True)

        target = pad_sequence(target, batch_first=True)
        target_mask = pad_sequence(target_mask, batch_first=True).bool()

        return source, source_mask, source_type_ids, target, target_mask

### Define the training iteration and validation

In [7]:
def train_one_iter(batch, fp16=False):
    """Run the forward and backward pass for one mini-batch.

    No optimizer step is taken here; the caller decides when to call
    ``optimizer.step()``. Relies on the module-level ``encoder``, ``decoder``,
    ``criterion``, ``optimizer``, ``device`` and
    ``num_gradients_accumulation``.

    Args:
        batch: ``(source_ids, source_mask, source_type_ids, target_ids,
            target_mask)`` tensors, in the order produced by
            ``SIAHeadlineDataset.mod_collate``.
        fp16: kept for interface compatibility. NOTE(review): the flag is not
            consulted -- the backward pass always goes through
            ``amp.scale_loss``.

    Returns:
        ``(record_loss, perplexity)``: the loss as a float, with the
        accumulation normalisation undone, and ``np.exp`` of that value.
    """
    source_ids, source_mask, source_type_ids, target_ids, target_mask = (
        item.to(device) for item in batch
    )

    _, past = encoder(source_ids,
                      token_type_ids=source_type_ids,
                      mask=source_mask)

    # The decoder receives the mask over source and target positions together.
    mask = torch.cat([source_mask, target_mask], dim=1)
    logits, _ = decoder(target_ids, mask=mask, past=past)

    # Shift by one: the logits at position t are scored against token t + 1.
    out = logits[:, :-1].contiguous()
    target = target_ids[:, 1:].contiguous()
    target_mask = target_mask[:, 1:].contiguous()

    loss = criterion(out, target, target_mask, label_smoothing=0.02, reduce=True)
    # Normalise so gradients accumulated over several calls average out.
    loss /= num_gradients_accumulation

    # fp16 support: backward on the scaled loss.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # Gradient clipping on the parameters the optimizer actually updates.
    # amp.master_params is the form the apex documentation prescribes; it also
    # stays correct if amp keeps separate FP32 master weights.
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)

    record_loss = loss.item() * num_gradients_accumulation
    perplexity = np.exp(record_loss)

    return record_loss, perplexity


def validate(dataloader):
    """Return the mean perplexity over every example in ``dataloader``.

    Switches the module-level ``encoder`` and ``decoder`` to eval mode (they
    are left in that mode) and runs without gradient tracking. Each batch is
    scored with ``eval_criterion`` using ``reduce="sentence"``; the
    exponentiated losses are collected and averaged with ``np.mean``.

    Args:
        dataloader: iterable of ``(source_ids, source_mask, source_type_ids,
            target_ids, target_mask)`` batches.
    """
    sentence_ppls = []

    with torch.no_grad():
        progress = tqdm.tqdm(dataloader)
        encoder.eval()
        decoder.eval()

        for raw_batch in progress:
            src_ids, src_mask, src_type_ids, tgt_ids, tgt_mask = (
                tensor.to(device) for tensor in raw_batch
            )

            _, past = encoder(src_ids, token_type_ids=src_type_ids, mask=src_mask)

            # The decoder receives the mask over source and target together.
            joint_mask = torch.cat([src_mask, tgt_mask], dim=1)
            logits, _ = decoder(tgt_ids, mask=joint_mask, past=past)

            # Shift by one: logits at position t are scored against token t + 1.
            shifted_logits = logits[:, :-1].contiguous()
            gold = tgt_ids[:, 1:].contiguous()
            gold_mask = tgt_mask[:, 1:].contiguous()

            sentence_loss = eval_criterion(
                shifted_logits, gold, gold_mask, label_smoothing=-1, reduce="sentence"
            )
            sentence_ppls.extend(torch.exp(sentence_loss).tolist())

    return np.mean(sentence_ppls)

## Load dataset

In [8]:
# Load the tokenized PHED train and validation splits.
# NOTE(review): torch.load unpickles the file, so only point these paths at
# data from a trusted source.
train_data = torch.load("../data/PHED/train_tokenized.pkl")
val_data = torch.load("../data/PHED/val_tokenized.pkl")

In [9]:
# Wrap both splits in the dataset class defined above.
# NOTE(review): the random choice inside __getitem__ applies to the validation
# split as well, so validation inputs depend on the random state.
train_dataset = SIAHeadlineDataset(train_data)
val_dataset = SIAHeadlineDataset(val_data)

### Create the dataloaders

In [10]:
batch_size = 5

# Training batches are reshuffled every epoch; both loaders pad each batch
# with the dataset's own collate function.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=SIAHeadlineDataset.mod_collate,
)

# Validation keeps the dataset order.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=SIAHeadlineDataset.mod_collate,
)


## Define Model

In [11]:
# Model config: display the class attributes of ChineseBERTBaseConfig
# (hidden size, number of layers, vocabulary size, ...).
vars(ChineseBERTBaseConfig)

mappingproxy({'__module__': 'torchfly.modules.transformers.model_configs',
              'attention_dropout_prob': 0.1,
              'hidden_dropout_prob': 0.1,
              'hidden_size': 768,
              'num_attention_heads': 12,
              'num_hidden_layers': 12,
              'intermediate_size': 3072,
              'layer_norm_eps': 1e-05,
              'max_position_embeddings': 512,
              'vocab_size': 21128,
              'type_vocab_size': 2,
              '__dict__': <attribute '__dict__' of 'ChineseBERTBaseConfig' objects>,
              '__weakref__': <attribute '__weakref__' of 'ChineseBERTBaseConfig' objects>,
              '__doc__': None})

In [12]:
# Build the encoder and the decoder language model from the same config.
encoder = CachedBertEncoder(ChineseBERTBaseConfig)
decoder = CachedBertDecoderLM(ChineseBERTBaseConfig)

In [13]:
# Load the adaptation checkpoint; you can either download it or train your own
# adaptation model.
# NOTE(review): strict=False tolerates missing or unexpected keys, and only the
# decoder's result is displayed because it is the cell's last expression --
# inspect the encoder's return value too if the checkpoint ever changes.
model_states = torch.load("../models/adapt.pth")

encoder.load_state_dict(model_states['encoder'], strict=False)
decoder.load_state_dict(model_states['decoder'], strict=False)

<All keys matched successfully>

In [14]:
# Send both models to the GPU. The device is hard-coded to CUDA; `device` is
# also read by train_one_iter and validate to move each batch.
device = torch.device("cuda")

encoder = encoder.to(device)
decoder = decoder.to(device)

In [15]:
# Training hyper-parameters.
num_epochs = 10
# Number of mini-batches whose gradients are accumulated per optimizer step.
num_gradients_accumulation = 1
# Total optimizer steps planned for the learning-rate schedule. Counted from
# the batches the loader actually yields, so a final partial batch per epoch
# is included (flooring len(dataset) / batch_size would under-count and let
# the schedule reach zero before training ends).
num_train_optimization_steps = len(train_dataloader) * num_epochs // num_gradients_accumulation

In [16]:
# Split the parameters into two optimizer groups: weight decay applies to
# everything except parameters whose name contains one of the `no_decay`
# markers (biases and LayerNorm parameters).
param_optimizer = [*encoder.named_parameters(), *decoder.named_parameters()]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

decay_params, no_decay_params = [], []
for param_name, param in param_optimizer:
    skip_decay = any(marker in param_name for marker in no_decay)
    (no_decay_params if skip_decay else decay_params).append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

In [17]:
# Set up the optimizer over the two parameter groups.
optimizer = FusedAdam(optimizer_grouped_parameters, 
                      lr=1e-5, 
                      eps=1e-06,
                      bias_correction=False)

# Schedule whose warmup covers the first 10% of the planned optimizer steps.
scheduler = WarmupLinearSchedule(
    optimizer, warmup_steps=int(num_train_optimization_steps * 0.1), t_total=num_train_optimization_steps
)

# enable fp16 (apex amp, opt_level O1); the models and the optimizer are
# rebound to the objects amp returns
[encoder, decoder], optimizer = amp.initialize([encoder, decoder], optimizer, opt_level='O1')

# set up losses: focal loss for training (alpha is passed as `gamma`, beta as
# `beta`), plain sequence cross entropy for validation
criterion = SequenceFocalLoss(gamma=a, beta=b)
eval_criterion = SequenceCrossEntropyLoss()

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [19]:
# Checkpoint writer. No model is attached because the training loop passes the
# encoder/decoder state dicts explicitly when it saves.
checkpointer = SimpleCheckpointer(model=None)

In [20]:
# Main loop: train for `num_epochs` epochs; after each epoch, validate, save a
# checkpoint named after the epoch index and log the validation perplexity.
update_count = 0
start = time.time()

for ep in range(num_epochs):
    # --- Training ---
    pb = tqdm.tqdm(train_dataloader)
    encoder.train()
    decoder.train()
    # Restart the timer so the first speed reading of an epoch does not
    # include the previous epoch's validation and checkpointing time.
    start = time.time()

    for batch in pb:
        record_loss, perplexity = train_one_iter(batch, fp16=True)
        update_count += 1

        # update_count is incremented before this check, so a step is due
        # exactly when it is a multiple of the accumulation size. (Comparing
        # against `num_gradients_accumulation - 1` would take the first step
        # one micro-batch early whenever accumulation is > 1.)
        if update_count % num_gradients_accumulation == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # speed measure: examples processed per second since the last step
            end = time.time()
            speed = batch_size * num_gradients_accumulation / (end - start)
            start = end

            pb.set_postfix(loss=record_loss, perplexity=perplexity, speed=speed)

    # --- Evaluation ---
    encoder.eval()
    decoder.eval()
    ppl = validate(val_dataloader)
    checkpointer.save_checkpoint(
        str(ep), {
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict()
        }, {"empty": None}
    )

    logger.info(f"a={a} b={b} Epoch {ep} Validation perplexity: {ppl}")

logger.info(f"Finish training of alpha={a} beta={b}")


  0%|          | 0/4600 [00:00<?, ?it/s][A
  0%|          | 0/4600 [00:00<?, ?it/s, loss=3.49, perplexity=32.7, speed=34.3][A

  0%|          | 0/300 [00:00<?, ?it/s][A[A

  1%|          | 3/300 [00:00<00:10, 28.88it/s][A[A

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0




  2%|▏         | 6/300 [00:00<00:10, 28.76it/s][A[A

  3%|▎         | 10/300 [00:00<00:09, 29.61it/s][A[A

  5%|▍         | 14/300 [00:00<00:09, 30.33it/s][A[A

  6%|▌         | 17/300 [00:00<00:09, 30.08it/s][A[A

  7%|▋         | 21/300 [00:00<00:09, 30.39it/s][A[A

  8%|▊         | 25/300 [00:00<00:09, 30.54it/s][A[A

 10%|▉         | 29/300 [00:00<00:08, 30.61it/s][A[A

 11%|█         | 32/300 [00:01<00:08, 30.32it/s][A[A

 12%|█▏        | 36/300 [00:01<00:08, 30.51it/s][A[A

 13%|█▎        | 39/300 [00:01<00:08, 29.97it/s][A[A

 14%|█▍        | 42/300 [00:01<00:08, 29.96it/s][A[A

 15%|█▌        | 46/300 [00:01<00:08, 30.38it/s][A[A

 17%|█▋        | 50/300 [00:01<00:08, 30.73it/s][A[A

 18%|█▊        | 54/300 [00:01<00:07, 30.76it/s][A[A

 19%|█▉        | 58/300 [00:01<00:07, 30.45it/s][A[A

 21%|██        | 62/300 [00:02<00:07, 30.34it/s][A[A

 22%|██▏       | 66/300 [00:02<00:07, 30.26it/s][A[A

 23%|██▎       | 70/300 [00:02<00:07, 30.12it/s

 78%|███████▊  | 234/300 [00:07<00:02, 31.04it/s][A[A[A


 79%|███████▉  | 238/300 [00:07<00:02, 30.26it/s][A[A[A


 81%|████████  | 242/300 [00:07<00:01, 30.40it/s][A[A[A


 82%|████████▏ | 246/300 [00:07<00:01, 31.14it/s][A[A[A


 83%|████████▎ | 250/300 [00:08<00:01, 31.74it/s][A[A[A


 85%|████████▍ | 254/300 [00:08<00:01, 32.11it/s][A[A[A


 86%|████████▌ | 258/300 [00:08<00:01, 32.12it/s][A[A[A


 87%|████████▋ | 262/300 [00:08<00:01, 32.33it/s][A[A[A


 89%|████████▊ | 266/300 [00:08<00:01, 31.66it/s][A[A[A


 90%|█████████ | 270/300 [00:08<00:00, 31.93it/s][A[A[A


 91%|█████████▏| 274/300 [00:08<00:00, 31.71it/s][A[A[A


 93%|█████████▎| 278/300 [00:08<00:00, 31.94it/s][A[A[A


 94%|█████████▍| 282/300 [00:09<00:00, 31.95it/s][A[A[A


 95%|█████████▌| 286/300 [00:09<00:00, 32.00it/s][A[A[A


 97%|█████████▋| 290/300 [00:09<00:00, 31.98it/s][A[A[A


 98%|█████████▊| 294/300 [00:09<00:00, 32.26it/s][A[A[A


100%|██████████| 300/300

 32%|███▏      | 97/300 [00:03<00:06, 31.46it/s][A[A[A[A[A




 34%|███▎      | 101/300 [00:03<00:06, 31.77it/s][A[A[A[A[A




 35%|███▌      | 105/300 [00:03<00:06, 31.92it/s][A[A[A[A[A




 36%|███▋      | 109/300 [00:03<00:05, 32.03it/s][A[A[A[A[A

KeyboardInterrupt: 