https://arxiv.org/pdf/1810.04805.pdf

In [1]:
import os
os.sys.path.append('..')

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import json
import math
import os
import random
import six
from tqdm import tqdm_notebook as tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import tokenization
from modeling import BertConfig, BertForMaskedLanguageModelling
from optimization import BERTAdam


In [4]:

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s', 
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

# Args

In [5]:
parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--data_dir",
                    default=None,
                    type=str,
                    required=True,
                    help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--bert_config_file",
                    default=None,
                    type=str,
                    required=True,
                    help="The config json file corresponding to the pre-trained BERT model. \n"
                         "This specifies the model architecture.")
parser.add_argument("--task_name",
                    default=None,
                    type=str,
                    required=True,
                    help="The name of the task to train.")
parser.add_argument("--vocab_file",
                    default=None,
                    type=str,
                    required=True,
                    help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_dir",
                    default=None,
                    type=str,
                    required=True,
                    help="The output directory where the model checkpoints will be written.")

## Other parameters
parser.add_argument("--init_checkpoint",
                    default=None,
                    type=str,
                    help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--do_lower_case",
                    default=False,
                    action='store_true',
                    help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length",
                    default=128,
                    type=int,
                    help="The maximum total input sequence length after WordPiece tokenization. \n"
                         "Sequences longer than this will be truncated, and sequences shorter \n"
                         "than this will be padded.")
parser.add_argument("--do_train",
                    default=False,
                    action='store_true',
                    help="Whether to run training.")
parser.add_argument("--do_eval",
                    default=False,
                    action='store_true',
                    help="Whether to run eval on the dev set.")
parser.add_argument("--train_batch_size",
                    default=32,
                    type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
                    default=8,
                    type=int,
                    help="Total batch size for eval.")
parser.add_argument("--learning_rate",
                    default=5e-5,
                    type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs",
                    default=3.0,
                    type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",
                    default=0.1,
                    type=float,
                    help="Proportion of training to perform linear learning rate warmup for. "
                         "E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda",
                    default=False,
                    action='store_true',
                    help="Whether not to use CUDA when available")
parser.add_argument("--local_rank",
                    type=int,
                    default=-1,
                    help="local_rank for distributed training on gpus")
parser.add_argument('--seed', 
                    type=int, 
                    default=42,
                    help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
                    type=int,
                    default=1,
                    help="Number of updates steps to accumualte before performing a backward/update pass.") 

_StoreAction(option_strings=['--gradient_accumulation_steps'], dest='gradient_accumulation_steps', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, help='Number of updates steps to accumualte before performing a backward/update pass.', metavar=None)

In [6]:
experiment_name = 'horror_uncased_4_tied_mlm'

argv = """
--task_name lm \
--data_dir {DATA_DIR} \
--vocab_file {BERT_BASE_DIR}/vocab.txt \
--bert_config_file {BERT_BASE_DIR}/bert_config.json \
--init_checkpoint {BERT_BASE_DIR}/pytorch_model.bin \
--do_train \
--do_eval \
--gradient_accumulation_steps 2 \
--train_batch_size 24 \
--learning_rate 3e-5 \
--num_train_epochs 3.0 \
--max_seq_length 128 \
--output_dir ../outputs/{name}/
""".format(
    BERT_BASE_DIR='../data/weights/cased_L-12_H-768_A-12',
    DATA_DIR='../data/input/horror_gutenberg',
    name=experiment_name
).replace('\n', '').split(' ')
print(argv)
args = parser.parse_args(argv)

['--task_name', 'lm', '--data_dir', '../data/input/horror_gutenberg', '--vocab_file', '../data/weights/cased_L-12_H-768_A-12/vocab.txt', '--bert_config_file', '../data/weights/cased_L-12_H-768_A-12/bert_config.json', '--init_checkpoint', '../data/weights/cased_L-12_H-768_A-12/pytorch_model.bin', '--do_train', '--do_eval', '--gradient_accumulation_steps', '2', '--train_batch_size', '24', '--learning_rate', '3e-5', '--num_train_epochs', '3.0', '--max_seq_length', '128', '--output_dir', '../outputs/horror_uncased_4_tied_mlm/']


# Init

In [7]:
if args.local_rank == -1 or args.no_cuda:
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
else:
    device = torch.device("cuda", args.local_rank)
    n_gpu = 1
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

if args.gradient_accumulation_steps < 1:
    raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                        args.gradient_accumulation_steps))

args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

if not args.do_train and not args.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

11/12/2018 06:15:27 - INFO - __main__ -   device cuda n_gpu 1 distributed training False


In [8]:
bert_config = BertConfig.from_json_file(args.bert_config_file)

if args.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
        args.max_seq_length, bert_config.max_position_embeddings))

In [9]:
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    print("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True)

Output directory (../outputs/horror_uncased_4_tied_mlm/) already exists and is not empty.


In [10]:
save_path = os.path.join(args.output_dir, 'state_dict.pkl')

# Load data v2

## Helpers

In [11]:
def notqdm(it, *a, **k):
    return it

In [262]:
# from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/run_classifier.py

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.
        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id, label_weights):
        self.input_ids = input_ids # inputs tokens with 103 for mask
        self.input_mask = input_mask # 0 for padding, 1 otherwise
        self.segment_ids = segment_ids # which sentance it's in
        self.label_id = label_id # labels, true tokens
        self.label_weights = label_weights


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()
        

class LMProcessor(DataProcessor):
    """Processor for language modelling."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            open(os.path.join(data_dir, "train.txt")).read(), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            open(os.path.join(data_dir, "val.txt")).read(), "dev")

    def get_labels(self):
        """See base class."""
        return list(tokenizer.vocab.keys())

    def _create_examples(self, lines, set_type, window_size=args.max_seq_length, tqdm=tqdm):
        """Creates examples for the training and dev sets."""
        tokens = []
        for line in tqdm(lines.split('\n\n'), desc='tokenising'):
            line = tokenization.convert_to_unicode(line)
            token = tokenizer.tokenize(line)
            tokens += token            
        
        examples = []
        for i, start_idx in tqdm(list(enumerate(range(len(tokens) - window_size - 1))), desc='chunking'):
            guid = "%s-%s" % (set_type, i)
            text_a = tokens[start_idx : start_idx + window_size]
            label = tokens[start_idx + window_size]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            
        if len(examples)==0:
            guid = "%s-%s" % (set_type, 0)
            text_a = tokens[:-1]
            label = tokens[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            
        return examples



In [263]:

def convert_tokens_to_features(examples, label_list, max_seq_length, tokenizer, tqdm=tqdm):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    features = []
    for (ex_index, example) in tqdm(list(enumerate(examples))):
        tokens_a = example.text_a

        tokens_b = None
        if example.text_b:
            tokens_b = example.text_b

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[0:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0   0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambigiously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
        
        # see https://github.com/google-research/bert/blob/d8014ef72/create_pretraining_data.py#L363
        label_weights = np.random.rand(len(input_ids))<0.10 # 10% change of masking
        label_keep = (np.random.rand(len(input_ids))<0.10) * label_weights # 10% chance of keeping
        label_switch = (np.random.rand(len(input_ids))<0.10) * label_weights # 10% chance of random word
        label_mask = label_weights * (1-label_keep) * (1-label_keep)
        
        switched_ids = np.random.randint(low=0, high=len(tokenizer.vocab)-1, size=(len(input_ids),))
        
        input_ids_masked = np.array(input_ids.copy())
        input_ids_masked[label_switch==1] = switched_ids[label_switch==1]
        input_ids_masked[label_mask==1] = 103
        input_ids_masked = input_ids_masked.tolist()
        
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(input_ids_masked) == max_seq_length
        
#         label_weights[torch.Tensor(input_ids)<103] = 0 # Don't mask CLS, PAD, ETC
#         input_ids_masked = input_ids.copy()
#         for i in range(len(label_weights)):
#             if label_weights[i]:
#                 if np.random.rand() < 0.8:
#                     input_ids_masked[i] = 103
#                 else:
#                     if np.random.rand() < 0.5:
#                         input_ids_masked[i] = np.random.randint(0, len(tokenizer.vocab)-1)
        
#         input_ids_masked[lms_masked==1] = 103
        # TODO 1% of the time replace masked ones with random work
        # TODO 1% of the time keep old one

#         if ex_index < 3:
#             logger.info("*** Example ***")
#             logger.info("guid: %s" % (example.guid))
#             logger.info("tokens: %s" % " ".join(
#                     [tokenization.printable_text(x) for x in tokens]))
#             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
#             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
#             logger.info(
#                     "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
#             logger.info("label: %s (id = %d)\n" % (example.label, label_id))

        features.append(
                InputFeatures(
                        input_ids=input_ids_masked,
                        input_mask=input_mask,
                        segment_ids=segment_ids,
                        label_id=input_ids,
                        label_weights=label_weights,))
    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

## Load

In [14]:
tokenizer = tokenization.FullTokenizer(
    vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)

decoder = {v:k for k,v in tokenizer.wordpiece_tokenizer.vocab.items()}

In [15]:
# window_size                        = 128


In [283]:
processors = {
        "lm": LMProcessor,
}
    
task_name = args.task_name.lower()
if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()
label_list = processor.get_labels()

In [17]:
train_examples = processor.get_train_examples(args.data_dir)
num_train_steps = int(
    len(train_examples) / args.train_batch_size * args.num_train_epochs)

HBox(children=(IntProgress(value=0, description='tokenising', max=53852), HTML(value='')))




HBox(children=(IntProgress(value=0, description='chunking', max=4480426), HTML(value='')))




In [18]:
train_features = convert_tokens_to_features(
    train_examples[::25], label_list, args.max_seq_length, tokenizer)

all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
all_label_weights = torch.tensor([f.label_weights for f in train_features], dtype=torch.long)

HBox(children=(IntProgress(value=0, max=179218), HTML(value='')))




In [19]:
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_label_weights)
if args.local_rank == -1:
    train_sampler = RandomSampler(train_data)
else:
    train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

# Load model

In [20]:
model = BertForMaskedLanguageModelling(bert_config)
if args.init_checkpoint is not None:
    model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
    
if os.path.isfile(save_path):
    model.load_state_dict(torch.load(save_path, map_location='cpu'))
    
model.to(device)

if args.local_rank != -1:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank)
elif n_gpu > 1:
    model = torch.nn.DataParallel(model)
    
model

BertForMaskedLanguageModelling(
  (bert): BertModel(
    (embeddings): BERTEmbeddings(
      (word_embeddings): Embedding(28996, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BERTLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BERTEncoder(
      (layer): ModuleList(
        (0): BERTLayer(
          (attention): BERTAttention(
            (self): BERTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BERTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BERTLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BERTInterm

# Opt

In [21]:
no_decay = ['bias', 'gamma', 'beta']
optimizer_parameters = [
    {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
    {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
    ]

optimizer = BERTAdam(optimizer_parameters,
                     lr=args.learning_rate,
                     warmup=args.warmup_proportion,
                t_total=num_train_steps)

# Predict helpers

In [312]:
import re
from IPython.display import HTML, display

b4='(\<span[^\>]+\>)?'
after='(\<\/span\>)?'

replace_list = [
    # punctuation
    ['\s{}\,'.format(b4), r'\1,'],
    ['\s{}\.'.format(b4), r'\1.'],
    ['\s{}\:'.format(b4), r'\1:'],
    ['\s{}\;'.format(b4), r'\1;'],
    ['\s{}\!'.format(b4), r'\1!'],
    ['\s{}\?'.format(b4), r'\1?'],
    ["\s{}\'{}\s?".format(b4,after), r"\1'\2"],
    ["\s{}\-{}\s".format(b4,after), r"\1-\2"],
    ["\“{}\s".format(after), r"“\1"],
    ["\s{}\”".format(b4), r"\1”"],
    ["\s{}\’".format(b4), r"\1’"],
    
    # tokenization
    ['\s?{}\#\#'.format(b4), r'\1'],    
#     ['\[CLS\]\s?', ''],
#     ['\s?\[SEP\]\s?', ''],
    ["\[\s?PAD\s?\]\s?", r""],
    ["\s?\¿\s?", r""],
#     [UNK]
#     [MASK]
#     [CLS]
    # TODO ideally I need to be able to do these with of without span tags
]

def clean_decoded(tokens):
    s = ' '.join(tokens)
    for a, b in replace_list:
        p = re.search(a, s)
        s = re.sub(a, b, s)
    return s


In [313]:
val_test="""The authorities at Scotland Yard are unable to suggest any explanation of these terrible occurrences."""
display(predict_next_words(val_test, processor, tokenizer, n=10, T=.5))

¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ ¿ The authorities at Scotland Yard are unable to suggest any explanation of these terrible occurrences.
['¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', '¿', 'The', 'authorities', 'at', 'Scotland', 'Yard', 'are', 'unable', 'to', 'suggest', 'any', 'explanation', 'o

In [347]:

def html_clean_decoded_logits(input_ids, logits, input_mask, label_weights):
    """Format model outputs as html, with masked elements in red, with opacity indicating confidence."""
    log_probs = nn.LogSoftmax(-1)(logits).detach()
    prediction_idxs = log_probs.argmax(-1)
    # join masked an non masked
    y = input_ids *  (1 - label_weights) + prediction_idxs * label_weights
    yd = [decoder[hh.item()] for hh in y]
    html_yd = []
    for i in range(len(yd)):
        if not label_weights[i]:
            html_yd.append(yd[i])
        else:
            prob = log_probs[i][prediction_idxs[i]].exp()
            prob = prob/2 + 0.5
            html_yd.append('<span style="color: rgba(255,0,0,{})">{}</span>'.format(prob, yd[i]))
    return clean_decoded(html_yd)

def html_clean_decoded(tokens, input_mask, label_weights):
    """Format model outputs as html, with masked elements in red, with opacity indicating confidence."""
    yd = [decoder[hh.item()] for hh in tokens]
    html_yd = []
    for i in range(len(yd)):
        if not label_weights[i]:
            html_yd.append(yd[i])
        else:
            prob = 1
            html_yd.append('<span style="color: rgba(255,0,0,{})">{}</span>'.format(prob, yd[i]))
    return clean_decoded(html_yd)

In [368]:



def predict_masked_words(x, processor, tokenizer, n=10):
    ex = processor._create_examples(x, "train", tqdm=notqdm)[-1:]
    label_list = processor.get_labels()

    log_feats = convert_tokens_to_features(ex, label_list, args.max_seq_length, tokenizer, tqdm=notqdm)

    log_input_ids = torch.tensor([f.input_ids for f in log_feats], dtype=torch.long)
    log_input_mask = torch.tensor([f.input_mask for f in log_feats], dtype=torch.long)
    log_segment_ids = torch.tensor([f.segment_ids for f in log_feats], dtype=torch.long)
    log_label_ids = torch.tensor([f.label_id for f in log_feats], dtype=torch.long)
    log_label_weights = torch.tensor([f.label_weights for f in log_feats], dtype=torch.long)

    batch = [log_input_ids, log_input_mask, log_segment_ids, log_label_ids, log_label_weights]

    with torch.no_grad():
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids, label_weights = batch
        logits = model(input_ids, segment_ids, input_mask).detach()

    i = 0
    display(HTML(html_clean_decoded(tokens=log_label_ids[i][1:-2], input_mask=input_mask[i][1:-2], label_weights=label_weights[i][1:-2]).replace('rgba(255,0,0', 'rgba(0,0,255')))
    display(HTML(html_clean_decoded_logits(input_ids=input_ids[i][1:-1], input_mask=input_mask[i][1:-1], logits=logits[i][1:-1], label_weights=label_weights[i][1:-1])))


In [481]:
# def predictions
def pad_seq(s1):
    # HACK: pad short sentances with
    x = '¿ ' * (args.max_seq_length + 2 - len(tokenizer.tokenize(s1))) + s1
    return x

def predict_next_words(text, processor, tokenizer, n=10, T=1.0):
    """
    Predict next `n` words for some `text`
    Args:
    - text (str) base string, we will predict next words
    - processor
    - tokenizer
    - n (int) amount of words to predict
    - T (float) temperature for when samping predictions
    
    Returns:
    - IPython html object, which show predicted words in red, with opacity indicating confidence
    """
    discarded = []
    x= pad_seq(text)
    ex = processor._create_examples(x, "train", tqdm=notqdm)[-1:]
    label_list = processor.get_labels()

    log_feats = convert_tokens_to_features(ex, label_list, args.max_seq_length, tokenizer, tqdm=notqdm)

    with torch.no_grad():

        log_input_ids = torch.tensor([f.input_ids for f in log_feats], dtype=torch.long)
        log_input_mask = torch.tensor([f.input_mask for f in log_feats], dtype=torch.long)
        log_segment_ids = torch.tensor([f.segment_ids for f in log_feats], dtype=torch.long)
        log_label_ids = torch.tensor([f.label_id for f in log_feats], dtype=torch.long)
        log_label_weights = torch.tensor([f.label_weights for f in log_feats], dtype=torch.long)

        # Now we only want to predict the next word... so remove our masks
        log_input_ids = log_label_ids * 1
        log_label_weights[:]=0

        # and add a mask token 2nd to last, and drop the first word (to keep max seq len)
        discarded.append(log_input_ids[0, 1])
        log_input_ids = torch.cat([torch.tensor([[101]]), log_input_ids[:, 2:-1], torch.tensor([[103, 102]])], -1)
        log_input_mask = torch.cat([log_input_mask[:, 1:], torch.tensor([[1]])], -1) # Add one to end
        log_label_weights[:, -2] = 1

        batch = [log_input_ids, log_input_mask, log_segment_ids, log_label_ids, log_label_weights]
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids, label_weights = batch
        
        

        for i in range(n):
            logits = model(input_ids, segment_ids, input_mask).detach()
            
            
            
            
            # sample outputs with probability...
            predictions = torch.distributions.Multinomial(logits=logits/T).sample()
            next_word = predictions[:, -2].argmax(dim = -1)

            # Add prediction to end, and update data tensor by rolling the contents
            # drop first part of content, add prediction to end of content (and put sides back: CLS=101 at start, and MASK=103, SEP=102 at end again)
            discarded.append(input_ids[0, 1])
            input_ids = torch.cat([torch.tensor([[101]]).cuda(), input_ids[:, 2:-2], torch.tensor([[next_word, 103, 102]]).cuda()], -1) 
            input_mask = torch.cat([input_mask[:, 1:], torch.tensor([[1]]).cuda()], -1) # drop first, add [1] to end
            label_weights = torch.cat([label_weights[:, 1:-1], torch.tensor([[1, 0]]).cuda()], -1) # drop 1st, add 1 to end of content
            
#            # I could print probabilities
#             log_probs = F.log_softmax(logits, -1)
#             print(decoder[next_word.item()], 'prob={:2.4e}'.format(log_probs[0, -1, next_word.item()].exp().item()))

    # TODO add discarded start tokens
    input_ids = torch.cat([torch.tensor([discarded]).cuda(), input_ids[:, 2:-2]], -1) 
    input_mask = torch.cat([torch.tensor([[1]*len(discarded)]).cuda(), input_mask[:, 2:-2]], -1) # drop first, add [1] to end
    label_weights = torch.cat([torch.tensor([[0]*len(discarded)]).cuda(), label_weights[:, 2:-2]], -1) # drop 1st, add 1 to end of content            

    batch = 0
    # return html fragment, cleaned, but cut of the first and last two tokens which are [CLS] and [MASK][SEP]
    return HTML(html_clean_decoded(tokens=input_ids[batch], input_mask=input_mask[batch], label_weights=label_weights[batch]))

In [487]:
val_test="""Another gentleman has fallen a victim to the terrible epidemic of suicide which for the last month has prevailed in the West End. Mr. Sidney Crashaw, of Stoke House, Fulham, and King's Pomeroy, Devon, was found, after a prolonged search, hanging dead from the branch of a tree in his garden at one o'clock today. The deceased gentleman dined last night at the Carlton Club and seemed in his usual health and spirits. He left the club at about ten o'clock, and was seen walking leisurely up St. James's Street a little later. Subsequent to this his movements cannot be traced"""
display(predict_next_words(val_test, processor, tokenizer, n=30, T=1))

In [483]:
display(predict_masked_words(val_test, processor, tokenizer))

None

# Train 3

In [27]:
global_step = 0

In [488]:
model.train()
for _ in tqdm(range(int(args.num_train_epochs)), desc="Epoch"):
    tr_loss, nb_tr_examples, nb_tr_steps = 0, 0, 0
    with tqdm(total=len(train_dataloader), desc='Iteration', mininterval=0.5) as prog:
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, label_weights = batch
            
            loss, logits = model(input_ids, segment_ids, input_mask, label_ids, label_weights)
            if n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()    # We have accumulated enougth gradients
                model.zero_grad()
            prog.update(1)
            prog.desc = 'Iter. loss={:2.6f}'.format(tr_loss/nb_tr_examples)
            if step%1000==10:
                
                print('step', step, 'loss', tr_loss/nb_tr_examples)
                display(predict_masked_words(val_test, processor, tokenizer))
                display(predict_next_words(val_test, processor, tokenizer, n=10))
                tr_loss, nb_tr_examples, nb_tr_steps = 0, 0, 0
                
    
    torch.save(model.state_dict(), save_path)

global_step += 1

HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Iteration', max=14935), HTML(value='')))

step 10 loss 0.13508770953525195


None

KeyboardInterrupt: 

In [372]:
torch.save(model.state_dict(), save_path)

In [373]:
val_test="""Another gentleman has fallen a victim to the terrible epidemic of suicide which for the last month has prevailed in the West End.  Mr. Sidney Crashaw, of Stoke House, Fulham, and King's Pomeroy, Devon, was found, after a prolonged search, hanging dead from the branch of a tree in his garden at one o'clock today.  The deceased gentleman dined last night at the Carlton Club and seemed in his usual health and spirits. He left the club at about ten o'clock, and was seen walking leisurely up St. James's Street a little later.  Subsequent to this his movements cannot be traced.  On the discovery of the body medical aid was at once summoned, but life had evidently been long extinct.  So far as is known, Mr. Crashaw had no trouble or anxiety of any kind.  This painful suicide, it will be remembered, is the fifth of the kind in the last month.  The authorities at Scotland Yard are unable to suggest any explanation of these terrible occurrences."""
display(predict_next_words(val_test, processor, tokenizer, n=100, T=.5))