https://arxiv.org/pdf/1810.04805.pdf

In [1]:
import os
os.sys.path.append('..')

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import json
import math
import os
import random
import six
from tqdm import tqdm_notebook as tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import tokenization
from modeling import BertConfig, BertForLanguageModelling
from optimization import BERTAdam


In [4]:

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s', 
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

# Args

In [5]:
parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--data_dir",
                    default=None,
                    type=str,
                    required=True,
                    help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--bert_config_file",
                    default=None,
                    type=str,
                    required=True,
                    help="The config json file corresponding to the pre-trained BERT model. \n"
                         "This specifies the model architecture.")
parser.add_argument("--task_name",
                    default=None,
                    type=str,
                    required=True,
                    help="The name of the task to train.")
parser.add_argument("--vocab_file",
                    default=None,
                    type=str,
                    required=True,
                    help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_dir",
                    default=None,
                    type=str,
                    required=True,
                    help="The output directory where the model checkpoints will be written.")

## Other parameters
parser.add_argument("--init_checkpoint",
                    default=None,
                    type=str,
                    help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--do_lower_case",
                    default=False,
                    action='store_true',
                    help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length",
                    default=128,
                    type=int,
                    help="The maximum total input sequence length after WordPiece tokenization. \n"
                         "Sequences longer than this will be truncated, and sequences shorter \n"
                         "than this will be padded.")
parser.add_argument("--do_train",
                    default=False,
                    action='store_true',
                    help="Whether to run training.")
parser.add_argument("--do_eval",
                    default=False,
                    action='store_true',
                    help="Whether to run eval on the dev set.")
parser.add_argument("--train_batch_size",
                    default=32,
                    type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
                    default=8,
                    type=int,
                    help="Total batch size for eval.")
parser.add_argument("--learning_rate",
                    default=5e-5,
                    type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs",
                    default=3.0,
                    type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",
                    default=0.1,
                    type=float,
                    help="Proportion of training to perform linear learning rate warmup for. "
                         "E.g., 0.1 = 10%% of training.")
parser.add_argument("--save_checkpoints_steps",
                    default=1000,
                    type=int,
                    help="How often to save the model checkpoint.")
parser.add_argument("--no_cuda",
                    default=False,
                    action='store_true',
                    help="Whether not to use CUDA when available")
parser.add_argument("--local_rank",
                    type=int,
                    default=-1,
                    help="local_rank for distributed training on gpus")
parser.add_argument('--seed', 
                    type=int, 
                    default=42,
                    help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
                    type=int,
                    default=1,
                    help="Number of updates steps to accumualte before performing a backward/update pass.") 

_StoreAction(option_strings=['--gradient_accumulation_steps'], dest='gradient_accumulation_steps', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, help='Number of updates steps to accumualte before performing a backward/update pass.', metavar=None)

- BERT_BASE_DIR/* is from pretrained model
- train-v1.1.json etc are training data e.g. https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json


In [6]:
experiment_name = 'scratch_3'

argv = """
--data_dir {DATA_DIR} \
--bert_config_file {BERT_BASE_DIR}/bert_config.json \
--task_name lm \
--vocab_file {BERT_BASE_DIR}/vocab.txt \
--init_checkpoint {BERT_BASE_DIR}/pytorch_model.bin \
--do_train \
--do_eval \
--do_lower_case \
--gradient_accumulation_steps 2 \
--train_batch_size 24 \
--learning_rate 3e-5 \
--num_train_epochs 2.0 \
--max_seq_length 128 \
--output_dir ../outputs/{name}/
""".format(
    BERT_BASE_DIR='../data/weights/cased_L-12_H-768_A-12',
    DATA_DIR='../data/input/erotic_gutenberg',
    name=experiment_name
).replace('\n', '').split(' ')
print(argv)
args = parser.parse_args(argv)

['--data_dir', '../data/input/erotic_gutenberg', '--bert_config_file', '../data/weights/cased_L-12_H-768_A-12/bert_config.json', '--task_name', 'lm', '--vocab_file', '../data/weights/cased_L-12_H-768_A-12/vocab.txt', '--init_checkpoint', '../data/weights/cased_L-12_H-768_A-12/pytorch_model.bin', '--do_train', '--do_eval', '--do_lower_case', '--gradient_accumulation_steps', '2', '--train_batch_size', '24', '--learning_rate', '3e-5', '--num_train_epochs', '2.0', '--max_seq_length', '128', '--output_dir', '../outputs/scratch_3/']


# Init

In [7]:
if args.local_rank == -1 or args.no_cuda:
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
else:
    device = torch.device("cuda", args.local_rank)
    n_gpu = 1
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

if args.gradient_accumulation_steps < 1:
    raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                        args.gradient_accumulation_steps))

args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

if not args.do_train and not args.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

11/10/2018 16:06:16 - INFO - __main__ -   device cuda n_gpu 1 distributed training False


In [8]:
bert_config = BertConfig.from_json_file(args.bert_config_file)

if args.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
        args.max_seq_length, bert_config.max_position_embeddings))

In [9]:
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    print("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True)

Output directory (../outputs/scratch_3/) already exists and is not empty.


In [10]:
save_path = os.path.join(args.output_dir, 'state_dict.pkl')

# Load data v2

## Helpers

In [11]:
# from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/run_classifier.py

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.
        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()
        

class LMProcessor(DataProcessor):
    """Processor for language modelling."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            open(os.path.join(data_dir, "train.txt")).read(), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            open(os.path.join(data_dir, "val.txt")).read(), "dev")

    def get_labels(self):
        """See base class."""
        return list(tokenizer.vocab.keys())

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        tokens = []
        for line in tqdm(lines.split('\n\n'), desc='tokenising'):
            line = tokenization.convert_to_unicode(line)
            token = tokenizer.tokenize(line)
            tokens += token
        
        context_list = []
        target_list = []
        examples = []
        for i, start_idx in tqdm(list(enumerate(range(len(tokens) - window_size - 1))), desc='chunking'):
            guid = "%s-%s" % (set_type, i)
            context = tokens[start_idx : start_idx + window_size]
            target = tokens[start_idx + window_size]
            examples.append(
                InputExample(guid=guid, text_a=context, text_b=None, label=target))
            
        return examples



In [12]:

def convert_tokens_to_features(examples, label_list, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    features = []
    for (ex_index, example) in tqdm(list(enumerate(examples))):
        tokens_a = example.text_a

        tokens_b = None
        if example.text_b:
            tokens_b = example.text_b

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[0:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0   0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambigiously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map[example.label]
        if ex_index < 3:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join(
                    [tokenization.printable_text(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s (id = %d)\n" % (example.label, label_id))

        features.append(
                InputFeatures(
                        input_ids=input_ids,
                        input_mask=input_mask,
                        segment_ids=segment_ids,
                        label_id=label_id))
    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

## Load

In [13]:
tokenizer = tokenization.FullTokenizer(
    vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)

decoder = {v:k for k,v in tokenizer.wordpiece_tokenizer.vocab.items()}

In [14]:
window_size                        = 128

In [15]:
processors = {
        "lm": LMProcessor,
}
    
task_name = args.task_name.lower()
if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()
label_list = processor.get_labels()

In [16]:
# generation_parameters = {
#     'model'        : model,
#     'text_encoder' : tokenizer,
#     'sentence'     : 'You had a great morning but your afternoon will be ruined because',
#     'window_size'  : window_size,
# #     'n_vocab'      : n_vocab,
#     'n_special'    : n_special,
#     'n_ctx'        : n_ctx,
#     'device'       : device,
#     'final_len'    : 150
# }

In [None]:
train_examples = processor.get_train_examples(args.data_dir)
num_train_steps = int(
    len(train_examples) / args.train_batch_size * args.num_train_epochs)

HBox(children=(IntProgress(value=0, description='tokenising', max=29232), HTML(value='')))

In [None]:
train_features = convert_tokens_to_features(
    train_examples, label_list, args.max_seq_length, tokenizer)

all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

In [None]:
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
    train_sampler = RandomSampler(train_data)
else:
    train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

In [None]:
# epochs = int(args.num_train_epochs)
# n_batch_train = args.train_batch_size

In [None]:
# n_train                     = len(y_train)
# n_valid                     = len(y_val) // 10
# n_updates_total             = (n_train // args.train_batch_size) * epochs

# Load model

In [None]:
model = BertForLanguageModelling(bert_config)
if args.init_checkpoint is not None:
    model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
    
if os.path.isfile(save_path):
    model.load_state_dict(torch.load(save_path, map_location='cpu'))
    
model.to(device)

if args.local_rank != -1:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank)
elif n_gpu > 1:
    model = torch.nn.DataParallel(model)
    
model

# Opt

In [None]:
no_decay = ['bias', 'gamma', 'beta']
optimizer_parameters = [
    {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
    {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
    ]

optimizer = BERTAdam(optimizer_parameters,
                     lr=args.learning_rate,
                     warmup=args.warmup_proportion,
                t_total=num_train_steps)

# Train 3

In [None]:
# TODO, gen data on the fly? 
# like iter_data in https://github.com/wassname/openai-transformer-lm-gutenberg-erotic/blob/master/utils.py

In [None]:
model.train()
for _ in tqdm(range(int(args.num_train_epochs)), desc="Epoch"):
    tr_loss = 0
    losses = []
    nb_tr_examples, nb_tr_steps = 0, 0
    with tqdm(total=len(train_dataloader), desc='Iteration', mininterval=0.5) as prog:
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
            if n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()    # We have accumulated enougth gradients
                model.zero_grad()
            prog.update(1)
            prog.desc = 'Iter. loss={:2.6f}'.format(tr_loss/nb_tr_examples)
            if step%1000==0:
                print(step, tr_loss/nb_tr_examples)
                # TODO show running mean
    
    torch.save(model.state_dict(), save_path)
            
global_step += 1

- TODO pred generator
- TODO sample outputs
- TODO check it works
- TODO isn't it a bit slow?
- TODO should I mirror the embedding weights?

In [None]:
torch.save(model.state_dict(), save_path)
save_path

In [None]:
torch.cuda.empty_cache()

# Predict 1

In [None]:
with torch.no_grad():
    batch = next(iter(train_dataloader))
    batch = tuple(t.to(device) for t in batch)
    input_ids, input_mask, segment_ids, label_ids = batch
    outs = model(input_ids, segment_ids, input_mask).detach()
    y = torch.nn.functional.log_softmax(outs).argmax(-1).cpu().numpy()

In [None]:
replace_list = [
    [' ##', ''],
    ['## ', ''],
    ['[CLS] ', ''],
    [' [SEP]', ''],
    [' ,', ','],
    [' .', '.'],
    [' :', ':'],
    [' ;', ';'],
    [' !', '!'],
    [' ?', '?'],
    [" ' ", "'"],
    [" '", "'"],
    [" - ", "-"],
    ["“ ", "“"],
    [" ”", "”"],
    ["’ ", "’"],
    [" ’", "’"],
]



In [None]:
for i in range(len(input_ids)):
    x = [decoder[yy] for yy in input_ids[i].cpu().numpy()]
    o = decoder[y[i]]
    x = ' '.join(x) + o
    for a, b in replace_list:
        x = x.replace(a, b)
    print('`{}`: `{}`\n'.format(x, o))

In [None]:
x

In [None]:
y = word_decode(outs[1][0])
y = torch.nn.functional.log_softmax(y).argmax(-1).cpu().item()
decoder[y]

In [None]:
y = tok_decode(outs[1][0])
y = torch.nn.functional.log_softmax(y).argmax(-1).cpu().item()
decoder[y]

In [None]:
y = pos_decode(outs[1][0])
y = torch.nn.functional.log_softmax(y).argmax(-1).cpu().item()
decoder[y]

In [None]:

tok_text = decoder[y]

# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")

# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
tok_text

In [None]:
def try_on_a_sentence(model, text_encoder, sentence, window_size,
                      n_vocab, n_special, n_ctx, device,
                      final_len = 200, temperature=1.0):
    model.eval()
    start_token  = text_encoder.encoder['_start_']
    clf_token    = text_encoder.encoder['_classify_']
    encoded_text = text_encoder.encode([sentence])[0]
    with tqdm(unit='word', total=final_len) as prog:
        while len(encoded_text) < final_len:
            # We take the last 'window_size' words of the text being generated
            # and run it through the model.
            context         = encoded_text[-window_size:]
            X_trans, X_mask = transform_dataset(
                [context],
                text_encoder,
                window_size,
                n_vocab,
                n_special,
                n_ctx
            )
            XMB                = torch.tensor(X_trans, dtype = torch.long).to(device)
            lm_logits          = model(XMB)

            # We truncate the resulting predictions to actual vocabulary
            # words in order to exclude special tokens and positional
            # embeddings.
            lm_logits          = lm_logits[:, : n_vocab]/temperature
            # Higher temperature mean all actions have the same probability. At low ones they are more deterministic.

            # We then select the logit corresponding to the 'clf_token'
            # position (last one of the sequence).
            X_trans_tensor     = torch.from_numpy(X_trans)
            clf_token_bool_idx = X_trans_tensor[0, :, 0] == clf_token

            # probabilistic sample so we don't get into loops
            predictions = torch.distributions.Multinomial(logits=lm_logits).sample().argmax(dim = 1)
            pred               = predictions[clf_token_bool_idx[1:]].item()
            encoded_text.append(pred)
            prog.update(1)
        prog.close()

    print(len(encoded_text), final_len)
    return decode_sentence(text_encoder, encoded_text)

# Predict TODO

In [None]:
if args.do_predict:
    eval_examples = read_squad_examples(
        input_file=args.predict_file, is_training=False)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    logger.info("***** Running predictions *****")
    logger.info("  Num orig examples = %d", len(eval_examples))
    logger.info("  Num split examples = %d", len(eval_features))
    logger.info("  Batch size = %d", args.predict_batch_size)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    if args.local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

    model.eval()
    all_results = []
    logger.info("Start evaluating")
    for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
        if len(all_results) % 1000 == 0:
            logger.info("Processing example: %d" % (len(all_results)))
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        with torch.no_grad():
            batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
        for i, example_index in enumerate(example_indices):
#             start_logits = batch_start_logits[i].detach().cpu().tolist()
            end_logits = batch_end_logits[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(RawResult(unique_id=unique_id,
                                         start_logits=start_logits,
                                         end_logits=end_logits))
    output_prediction_file = os.path.join(args.output_dir, "predictions.json")
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
    write_predictions(eval_examples, eval_features, all_results,
                      args.n_best_size, args.max_answer_length,
                      args.do_lower_case, output_prediction_file,
                      output_nbest_file, args.verbose_logging)


