<h1>训练QA<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#初始化" data-toc-modified-id="初始化-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>初始化</a></span></li><li><span><a href="#QA-dataset" data-toc-modified-id="QA-dataset-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>QA dataset</a></span></li><li><span><a href="#QA-model" data-toc-modified-id="QA-model-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>QA model</a></span></li><li><span><a href="#训练" data-toc-modified-id="训练-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>训练</a></span><ul class="toc-item"><li><span><a href="#超参数" data-toc-modified-id="超参数-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>超参数</a></span></li><li><span><a href="#辅助函数" data-toc-modified-id="辅助函数-4.2"><span class="toc-item-num">4.2&nbsp;&nbsp;</span>辅助函数</a></span></li><li><span><a href="#实例化" data-toc-modified-id="实例化-4.3"><span class="toc-item-num">4.3&nbsp;&nbsp;</span>实例化</a></span></li><li><span><a href="#开始训练" data-toc-modified-id="开始训练-4.4"><span class="toc-item-num">4.4&nbsp;&nbsp;</span>开始训练</a></span></li></ul></li></ul></div>

# 初始化

In [1]:
import os
import sys
from argparse import Namespace
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import scipy.sparse as sp

import ujson as json
from tqdm.notebook import tqdm
from transformers import AutoConfig, AutoTokenizer
from apex import amp

# QA dataset

In [2]:
from datasets import HotpotQA_QA_Dataset, find_ans_spans, generate_QA_batches

# QA model

In [3]:
from XLNET_QA import XLNetForQuestionAnswering_customized as XLNetQAModel

# 训练

## 超参数

In [4]:
# All tunable settings for this notebook, gathered in one mapping and exposed
# as attributes (args.<name>) through argparse.Namespace.
args = Namespace(**{
    # --- Data and path information ---
    'json_train_path': './data/hotpot_train_v1.1.json',
    'json_train_mini_path': './data/hotpot_train_mini.json',
    'model_state_file': "HotpotQA_QA.pt",
    'save_dir': 'save_cache',
    'hotpotQA_item_folder': 'save_preprocess_new',
    # NOTE(review): absolute, machine-specific path to the pretrained weights.
    'model_path': '/g/data/models/xlnet-large-cased',
    'use_proxy': False,
    'proxies': {
        "http_proxy": "127.0.0.1:10809",
        "https_proxy": "127.0.0.1:10809",
    },

    # --- Training hyper parameters ---
    'num_epochs': 10,
    'learning_rate': 1e-3,
    'batch_size': 32,
    'seed': 1337,
    'early_stopping_criteria': 5,

    # --- Runtime hyper parameters ---
    'cuda': True,
    'device': None,
    'catch_keyboard_interrupt': True,
    'reload_from_files': True,
    'expand_filepaths_to_save_dir': True,
})

In [5]:
def set_seed_everywhere(seed, cuda):
    """Seed the random number generators of Python, NumPy and PyTorch.

    Args:
        seed (int): value passed to every generator.
        cuda (bool): when true, the generators of all CUDA devices are
            seeded as well.
    """
    # Local import: `random` is not among the notebook's top-level imports.
    import random

    # Python's own generator is seeded too, so that any code relying on the
    # `random` module is covered by the same seed as NumPy and torch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)

def handle_dirs(dirpath):
    """Create ``dirpath`` and any missing parents unless the path already exists.

    Args:
        dirpath (str): directory to create.
    """
    # Nothing to do when something already lives at this path.
    if os.path.exists(dirpath):
        return
    os.makedirs(dirpath)

# Fall back to CPU when CUDA was requested but is not available.
if not torch.cuda.is_available():
    args.cuda = False

# Only pick a device when none was configured explicitly.
if not args.device:
    args.device = torch.device("cuda" if args.cuda else "cpu")

if args.expand_filepaths_to_save_dir:
    # Prefix the checkpoint file with the save directory exactly once.
    # Without the guard, re-running this cell on a live kernel keeps nesting
    # the path (save_dir/save_dir/...), because args is mutated in place.
    if not os.path.normpath(args.model_state_file).startswith(
            os.path.normpath(args.save_dir) + os.sep):
        args.model_state_file = os.path.join(args.save_dir, args.model_state_file)

set_seed_everywhere(args.seed, args.cuda)

# Make sure the checkpoint directory exists before training starts.
handle_dirs(args.save_dir)

print("Using: {}".format(args.device))

Using: cuda


## 辅助函数

In [6]:
def make_train_state(args):
    """Build the initial bookkeeping dictionary used during training.

    Args:
        args: object exposing ``learning_rate`` and ``model_state_file``.

    Returns:
        dict: early-stopping fields, the learning rate, the epoch index, one
        empty history list per (split, metric) pair and the checkpoint
        filename.
    """
    state = {
        'stop_early': False,
        'early_stopping_step': 0,
        'early_stopping_best_val': 1e8,
        'learning_rate': args.learning_rate,
        'epoch_index': 0,
    }

    # One independent, empty history list per split and per tracked metric.
    for split in ('train', 'val', 'test'):
        for metric in ('loss', 'ans_span_accuracy', 'yes_no_span_accuracy'):
            state['{}_running_{}'.format(split, metric)] = []

    state['model_filename'] = args.model_state_file
    return state

def _save_checkpoint(model, optimizer, filename):
    """Write the model, optimizer and apex-amp state dicts to ``filename``."""
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'amp': amp.state_dict()
    }
    torch.save(checkpoint, filename)


def update_train_state(args, model, optimizer, train_state):
    """Checkpoint the model when appropriate and update early-stopping state.

    Epoch 0 always writes a checkpoint. From epoch 1 on, the last two
    entries of ``train_state['val_running_loss']`` are compared: a loss that
    did not decrease increments the early-stopping counter, a decrease resets
    it and writes a checkpoint if the loss is also the best seen so far.

    Args:
        args: object exposing ``early_stopping_criteria``.
        model: object whose ``state_dict()`` is saved.
        optimizer: object whose ``state_dict()`` is saved.
        train_state (dict): state as created by ``make_train_state``;
            modified in place.

    Returns:
        dict: the same ``train_state`` object.
    """
    # Save one model at least
    if train_state['epoch_index'] == 0:
        _save_checkpoint(model, optimizer, train_state['model_filename'])

        # Remember the loss that belongs to the checkpoint just written.
        # Otherwise the best value stays at its initial sentinel, and a later
        # epoch that merely beats its predecessor could overwrite a better
        # epoch-0 checkpoint.
        val_losses = train_state.get('val_running_loss')
        if val_losses:
            train_state['early_stopping_best_val'] = val_losses[-1]

        train_state['stop_early'] = False

    # Save model if performance improved
    elif train_state['epoch_index'] >= 1:
        loss_tm1, loss_t = train_state['val_running_loss'][-2:]

        if loss_t >= loss_tm1:
            # Loss did not improve on the previous epoch: count it.
            train_state['early_stopping_step'] += 1
        else:
            # Loss decreased; keep the checkpoint only if it is a new best.
            if loss_t < train_state['early_stopping_best_val']:
                _save_checkpoint(model, optimizer, train_state['model_filename'])
                train_state['early_stopping_best_val'] = loss_t

            # Reset early stopping step
            train_state['early_stopping_step'] = 0

        # Stop once the loss failed to decrease for enough consecutive epochs.
        train_state['stop_early'] = \
            train_state['early_stopping_step'] >= args.early_stopping_criteria

    return train_state

def compute_span_accuracy(start_logits, start_positions, end_logits, end_positions):
    """Fraction of examples whose predicted start AND end both match the target.

    Predictions are the arg-max over the last dimension of each logits
    tensor. Only entries whose start position differs from -100 are counted
    in the denominator.

    Args:
        start_logits: scores for the start position; assumed to be
            (batch, seq_len) - TODO confirm against the model output.
        start_positions: target start indices, -100 marking entries to skip.
        end_logits: scores for the end position, same layout as start_logits.
        end_positions: target end indices.

    Returns:
        float: exact-match ratio, or 0 when no entry has a start target.
    """
    _, start_indices = start_logits.max(-1)
    _, end_indices = end_logits.max(-1)

    # A span is correct only when both of its boundaries are correct.
    correct = start_indices.eq(start_positions) & end_indices.eq(end_positions)

    numerator = correct.sum().item()
    denominator = start_positions.ne(-100).sum().item()

    # No entry carries a span target, so there is nothing to score.
    if denominator == 0:
        return 0

    return float(numerator) / denominator

def compute_accuracy(logits, labels):
    """Share of non-ignored labels matched by the arg-max of ``logits``.

    Args:
        logits: scores whose last dimension is reduced with arg-max.
        labels: target class indices; entries equal to -100 are left out of
            the denominator.

    Returns:
        float: accuracy, or 0 when every label equals -100.
    """
    valid_count = labels.ne(-100).sum().item()
    if valid_count == 0:
        return 0

    predicted = torch.max(logits, dim=-1)[1]
    hit_count = torch.eq(predicted, labels).sum().item()
    return float(hit_count) / valid_count

## 实例化

In [7]:
# Tokenizer and QA model are both loaded from the configured model directory;
# the model is moved to the selected device and put into training mode.
tokenizer_XLNET = AutoTokenizer.from_pretrained(args.model_path)
classifier = XLNetQAModel.from_pretrained(args.model_path, local_files_only=True)
classifier = classifier.to(args.device)
classifier.train()

# Targets equal to -100 are skipped by the loss.
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

# Optimise only the parameters that require gradients.
optimizer = optim.Adam(
    (param for param in classifier.parameters() if param.requires_grad),
    lr=args.learning_rate,
)

# Mixed precision via apex: amp wraps both the model and the optimizer.
opt_level = 'O1'
classifier, optimizer = amp.initialize(classifier, optimizer, opt_level=opt_level)

# Halve the learning rate once the monitored value stops decreasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

train_state = make_train_state(args)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [8]:
# Build the QA dataset from the full training JSON file.
dataset = HotpotQA_QA_Dataset.build_dataset(args.json_train_path)
# Attach the XLNet tokenizer; topN_sents is presumably the number of
# sentences kept per example - confirm in the datasets module.
dataset.set_parameters(tokenizer = tokenizer_XLNET, topN_sents = 6)
print(dataset)

HotpotQA QA Dataset. mode: train. size: 63312. sents num: 6


## 开始训练

In [9]:
# Main training routine: for every epoch, run one pass over the 'train' split
# with optimisation, one pass over the 'val' split without gradients, record
# the running metrics in train_state, then checkpoint / early-stop / adjust
# the learning rate based on the validation loss.
# NOTE(review): KeyboardInterrupt is always caught here; the
# args.catch_keyboard_interrupt flag is not consulted.
try:
    # One bar for epochs and one per split; each split bar's total is the
    # batch count reported by the dataset while that split is active.
    # NOTE(review): train_bar and val_bar both use position=1 - confirm this
    # is intended.
    epoch_bar = tqdm(desc='training routine',
                    total=args.num_epochs,
                    position=0)

    dataset.set_split('train')
    train_bar = tqdm(desc='split=train',
                    total=dataset.get_num_batches(args.batch_size), 
                    position=1)

    dataset.set_split('val')
    val_bar = tqdm(desc='split=val',
                    total=dataset.get_num_batches(args.batch_size), 
                    position=1)

    for epoch_index in range(args.num_epochs):

        train_state['epoch_index'] = epoch_index

        # ---- Training pass ----
        # Select the split, create a fresh batch generator, zero the running
        # metrics and switch the model to training mode.
        dataset.set_split('train')
        batch_generator = generate_QA_batches(dataset,
                                        batch_size=args.batch_size, 
                                        device=args.device)
        running_loss = 0.0
        running_ans_span_accuracy = 0.0
        running_yes_no_span_accuracy = 0.0

        classifier.train()

        for batch_index, batch_dict in enumerate(batch_generator):
            optimizer.zero_grad()
            # The yes/no/span label is removed from the batch before the
            # forward call, so it is not passed to the model.
            yes_no_span = batch_dict.pop('yes_no_span')
            res = classifier(**batch_dict)
            start_logits, end_logits, cls_logits = res[0], res[1], res[2]
            
            # Total loss = mean of start/end losses + half the
            # classification loss.
            # NOTE(review): if every target in a batch equals the ignore
            # index (-100), the mean-reduced loss is presumably NaN - confirm.
            start_loss = loss_fct(start_logits, batch_dict['start_positions'])
            end_loss = loss_fct(end_logits, batch_dict['end_positions'])
            start_end_loss = (start_loss + end_loss) / 2
            yes_no_span_loss = loss_fct(cls_logits, yes_no_span) / 2

            ans_span_accuracy = compute_span_accuracy(start_logits, batch_dict['start_positions'],
                                                        end_logits, batch_dict['end_positions'])
            yes_no_span_accuracy = compute_accuracy(cls_logits, yes_no_span)
            
            loss = start_end_loss + yes_no_span_loss
            # Incremental mean over the batches seen so far in this epoch.
            # Batches for which an accuracy helper returns 0 because it had
            # nothing to score still enter this mean as 0.
            running_loss += (loss.item() - running_loss) / (batch_index + 1)
            running_ans_span_accuracy  += (ans_span_accuracy - running_ans_span_accuracy) / (batch_index + 1)
            running_yes_no_span_accuracy  += (yes_no_span_accuracy - running_yes_no_span_accuracy) / (batch_index + 1)
            
            # Backward pass goes through apex amp's loss scaling.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            optimizer.step()
            
            # Show the current running metrics on the training bar.
            train_bar.set_postfix(running_loss=running_loss,
                                  running_ans_span_accuracy=running_ans_span_accuracy,
                                  running_yes_no_span_accuracy=running_yes_no_span_accuracy,
                                  epoch=epoch_index)
            train_bar.update()

        # Store this epoch's final running training metrics.
        train_state['train_running_loss'].append(running_loss)
        train_state['train_running_ans_span_accuracy'].append(running_ans_span_accuracy)
        train_state['train_running_yes_no_span_accuracy'].append(running_yes_no_span_accuracy)

        # ---- Validation pass ----
        # Same setup as the training pass, but on the 'val' split, with the
        # model in eval mode and no optimisation step.

        dataset.set_split('val')
        batch_generator = generate_QA_batches(dataset,
                                        batch_size=args.batch_size, 
                                        device=args.device)
        running_loss = 0.0
        running_ans_span_accuracy = 0.0
        running_yes_no_span_accuracy = 0.0
        
        classifier.eval()

        for batch_index, batch_dict in enumerate(batch_generator):
            # Gradients are disabled for the whole forward/metric computation.
            with torch.no_grad():

                yes_no_span = batch_dict.pop('yes_no_span')
                res = classifier(**batch_dict)
                start_logits, end_logits, cls_logits = res[0], res[1], res[2]

                # Same loss composition as in the training pass.
                start_loss = loss_fct(start_logits, batch_dict['start_positions'])
                end_loss = loss_fct(end_logits, batch_dict['end_positions'])
                start_end_loss = (start_loss + end_loss) / 2
                yes_no_span_loss = loss_fct(cls_logits, yes_no_span) / 2

                ans_span_accuracy = compute_span_accuracy(start_logits, batch_dict['start_positions'],
                                                            end_logits, batch_dict['end_positions'])
                yes_no_span_accuracy = compute_accuracy(cls_logits, yes_no_span)

                loss = start_end_loss + yes_no_span_loss
                # Incremental mean over the validation batches seen so far.
                running_loss += (loss.item() - running_loss) / (batch_index + 1)
                running_ans_span_accuracy  += (ans_span_accuracy - running_ans_span_accuracy) / (batch_index + 1)
                running_yes_no_span_accuracy  += (yes_no_span_accuracy - running_yes_no_span_accuracy) / (batch_index + 1)



            # Show the current running metrics on the validation bar.
            val_bar.set_postfix(running_loss=running_loss,
                                  running_ans_span_accuracy=running_ans_span_accuracy,
                                  running_yes_no_span_accuracy=running_yes_no_span_accuracy,
                                  epoch=epoch_index)
            val_bar.update()

        # Store this epoch's final running validation metrics.
        train_state['val_running_loss'].append(running_loss)
        train_state['val_running_ans_span_accuracy'].append(running_ans_span_accuracy)
        train_state['val_running_yes_no_span_accuracy'].append(running_yes_no_span_accuracy)

        # Checkpointing and early-stopping bookkeeping.
        train_state = update_train_state(args=args, model=classifier, 
                                         optimizer = optimizer,
                                         train_state=train_state)

        # The scheduler monitors the latest validation loss.
        scheduler.step(train_state['val_running_loss'][-1])

        # Rewind the per-split bars for the next epoch by resetting their
        # counters directly.
        train_bar.n = 0
        val_bar.n = 0
        epoch_bar.update()

        if train_state['stop_early']:
            break
except KeyboardInterrupt:
    # A manual interrupt ends training without raising.
    print("Exiting loop")

HBox(children=(FloatProgress(value=0.0, description='training routine', max=10.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='split=train', max=1978.0, style=ProgressStyle(description…

HBox(children=(FloatProgress(value=0.0, description='split=val', max=847.0, style=ProgressStyle(description_wi…

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 256.0
Exiting loop
