### Runned on transformers vir env

In [94]:
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook

prefix = '../datasets/yelp_review_polarity_csv/'

In [95]:
train_df = pd.read_csv(prefix + 'train.csv', header=None)
train_df.head(10)

Unnamed: 0,0,1
0,1,"Unfortunately, the frustration of being Dr. Go..."
1,2,Been going to Dr. Goldberg for over 10 years. ...
2,1,I don't know what Dr. Goldberg was like before...
3,1,I'm writing this review to give you a heads up...
4,2,All the food is great here. But the best thing...
5,1,Wing sauce is like water. Pretty much a lot of...
6,1,Owning a driving range inside the city limits ...
7,1,This place is absolute garbage... Half of the...
8,2,Before I finally made it over to this range I ...
9,2,I drove by yesterday to get a sneak peak. It ...


In [96]:
test_df = pd.read_csv(prefix + 'test.csv', header=None)
test_df.head(10)

Unnamed: 0,0,1
0,2,"Contrary to other reviews, I have zero complai..."
1,1,Last summer I had an appointment to get new ti...
2,2,"Friendly staff, same starbucks fair you get an..."
3,1,The food is good. Unfortunately the service is...
4,2,Even when we didn't have a car Filene's Baseme...
5,2,"Picture Billy Joel's \""Piano Man\"" DOUBLED mix..."
6,1,Mediocre service. COLD food! Our food waited s...
7,1,Ok! Let me tell you about my bad experience fi...
8,1,I used to love D&B when it first opened in the...
9,2,"Like any Barnes & Noble, it has a nice comfy c..."


In [97]:
# First column (mod 2)
train_df[0] = (train_df[0] == 2).astype(int)
test_df[0] = (test_df[0] == 2).astype(int)

In [98]:
train_df = pd.DataFrame({
    'id':range(len(train_df)),
    'label':train_df[0],
    'alpha':['a']*train_df.shape[0],
    'text': train_df[1].replace(r'\n', ' ', regex=True)
})

train_df.head()

Unnamed: 0,id,label,alpha,text
0,0,0,a,"Unfortunately, the frustration of being Dr. Go..."
1,1,1,a,Been going to Dr. Goldberg for over 10 years. ...
2,2,0,a,I don't know what Dr. Goldberg was like before...
3,3,0,a,I'm writing this review to give you a heads up...
4,4,1,a,All the food is great here. But the best thing...


In [99]:
dev_df = pd.DataFrame({
    'id':range(len(test_df)),
    'label':test_df[0],
    'alpha':['a']*test_df.shape[0],
    'text': test_df[1].replace(r'\n', ' ', regex=True)
})

dev_df.head()

Unnamed: 0,id,label,alpha,text
0,0,1,a,"Contrary to other reviews, I have zero complai..."
1,1,0,a,Last summer I had an appointment to get new ti...
2,2,1,a,"Friendly staff, same starbucks fair you get an..."
3,3,0,a,The food is good. Unfortunately the service is...
4,4,1,a,Even when we didn't have a car Filene's Baseme...


In [100]:
train_df.to_csv(prefix + 'train.tsv', sep='\t', index=False, header=False)
dev_df.to_csv(prefix + 'dev.tsv', sep='\t', index=False, header=False)

In [101]:
from __future__ import absolute_import, division, print_function

import glob
import logging
import os
import random
import json
import math

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
import random
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm_notebook, trange
from tensorboardX import SummaryWriter

from pytorch_transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer,
                                  XLMConfig, XLMForSequenceClassification, XLMTokenizer, 
                                  XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer,
                                  RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)

from pytorch_transformers import AdamW, WarmupLinearSchedule

import sys
sys.path.insert(1, '../src')
from utils import (convert_examples_to_features,
                        output_modes, processors)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [102]:
args = {
    'data_dir': '../datasets/yelp_review_polarity_csv/',
    'model_type':  'bert',
    'model_name': 'bert-base-cased',
    'task_name': 'binary',
    'output_dir': 'outputs/',
    'cache_dir': 'cache/',
    'do_train': True,
    'do_eval': True,
    'fp16': False,
    'fp16_opt_level': 'O1',
    'max_seq_length': 128,
    'output_mode': 'classification',
    'train_batch_size': 8,
    'eval_batch_size': 8,

    'gradient_accumulation_steps': 1,
    'num_train_epochs': 1,
    'weight_decay': 0,
    'learning_rate': 4e-5,
    'adam_epsilon': 1e-8,
    'warmup_steps': 0,
    'max_grad_norm': 1.0,

    'logging_steps': 50,
    'evaluate_during_training': False,
    'save_steps': 2000,
    'eval_all_checkpoints': True,

    'overwrite_output_dir': False,
    'reprocess_input_data': True,
    'notes': 'Using Yelp Reviews dataset'
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [103]:
with open('args.json', 'w') as f:
    json.dump(args, f)

In [104]:
if os.path.exists(args['output_dir']) and os.listdir(args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:
    raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args['output_dir']))

In [105]:
MODEL_CLASSES = {
    'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
    'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)
}

config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]

In [106]:
config = config_class.from_pretrained(args['model_name'], num_labels=2, finetuning_task=args['task_name'])
tokenizer = tokenizer_class.from_pretrained(args['model_name'])

INFO:pytorch_transformers.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json not found in cache or force_download set to True, downloading to C:\Users\berna\AppData\Local\Temp\tmp0ey4qcut



100%|██████████| 361/361 [00:00<00:00, 362148.71B/s]
INFO:pytorch_transformers.file_utils:copying C:\Users\berna\AppData\Local\Temp\tmp0ey4qcut to cache at C:\Users\berna\.cache\torch\pytorch_transformers\b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.3d5adf10d3445c36ce131f4c6416aa62e9b58e1af56b97664773f4858a46286e
INFO:pytorch_transformers.file_utils:creating metadata file for C:\Users\berna\.cache\torch\pytorch_transformers\b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.3d5adf10d3445c36ce131f4c6416aa62e9b58e1af56b97664773f4858a46286e
INFO:pytorch_transformers.file_utils:removing temp file C:\Users\berna\AppData\Local\Temp\tmp0ey4qcut
INFO:pytorch_transformers.modeling_utils:loading configuration file https://s3.amazonaws.

In [107]:
model = model_class.from_pretrained(args['model_name'])

INFO:pytorch_transformers.modeling_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at C:\Users\berna\.cache\torch\pytorch_transformers\b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.3d5adf10d3445c36ce131f4c6416aa62e9b58e1af56b97664773f4858a46286e
INFO:pytorch_transformers.modeling_utils:Model config {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "vocab_size": 28996
}

INFO:pytorch_transformers.file_utils:https://s3.amazonaws.com

 26%|██▋       | 114754560/435779157 [00:09<00:24, 13094110.57B/s][A[A[A


 27%|██▋       | 116074496/435779157 [00:09<00:24, 13090272.19B/s][A[A[A


 27%|██▋       | 117401600/435779157 [00:09<00:24, 13070763.45B/s][A[A[A


 27%|██▋       | 118728704/435779157 [00:09<00:24, 13100962.21B/s][A[A[A


 28%|██▊       | 120039424/435779157 [00:09<00:24, 13095551.83B/s][A[A[A


 28%|██▊       | 121350144/435779157 [00:09<00:24, 13065826.91B/s][A[A[A


 28%|██▊       | 122677248/435779157 [00:09<00:23, 13057036.89B/s][A[A[A


 28%|██▊       | 124004352/435779157 [00:10<00:23, 13084308.38B/s][A[A[A


 29%|██▉       | 125315072/435779157 [00:10<00:23, 13053023.31B/s][A[A[A


 29%|██▉       | 126630912/435779157 [00:10<00:23, 13071623.22B/s][A[A[A


 29%|██▉       | 127952896/435779157 [00:10<00:23, 13057053.56B/s][A[A[A


 30%|██▉       | 129288192/435779157 [00:10<00:23, 13108825.34B/s][A[A[A


 30%|██▉       | 130607104/435779157 [00:10<00:23, 13076507.52B/

 58%|█████▊    | 253684736/435779157 [00:20<00:13, 13045828.24B/s][A[A[A


 59%|█████▊    | 255010816/435779157 [00:20<00:13, 13075111.94B/s][A[A[A


 59%|█████▉    | 256321536/435779157 [00:20<00:13, 13083072.11B/s][A[A[A


 59%|█████▉    | 257632256/435779157 [00:20<00:13, 13081586.35B/s][A[A[A


 59%|█████▉    | 258942976/435779157 [00:20<00:13, 13024689.91B/s][A[A[A


 60%|█████▉    | 260286464/435779157 [00:20<00:13, 13107807.20B/s][A[A[A


 60%|██████    | 261598208/435779157 [00:20<00:13, 13057659.16B/s][A[A[A


 60%|██████    | 262924288/435779157 [00:20<00:13, 13110573.88B/s][A[A[A


 61%|██████    | 264236032/435779157 [00:20<00:13, 13097980.52B/s][A[A[A


 61%|██████    | 265546752/435779157 [00:20<00:13, 13054570.25B/s][A[A[A


 61%|██████    | 266872832/435779157 [00:21<00:12, 13053341.90B/s][A[A[A


 62%|██████▏   | 268199936/435779157 [00:21<00:12, 13092058.33B/s][A[A[A


 62%|██████▏   | 269510656/435779157 [00:21<00:12, 13068152.74B/

 90%|████████▉ | 391821312/435779157 [00:30<00:03, 12104238.67B/s][A[A[A


 90%|█████████ | 393144320/435779157 [00:30<00:03, 12395045.25B/s][A[A[A


 91%|█████████ | 394387456/435779157 [00:30<00:03, 11440277.54B/s][A[A[A


 91%|█████████ | 395831296/435779157 [00:31<00:03, 12173315.40B/s][A[A[A


 91%|█████████ | 397073408/435779157 [00:31<00:03, 11594156.68B/s][A[A[A


 91%|█████████▏| 398370816/435779157 [00:31<00:03, 11949844.83B/s][A[A[A


 92%|█████████▏| 399585280/435779157 [00:31<00:03, 10887798.08B/s][A[A[A


 92%|█████████▏| 400706560/435779157 [00:31<00:03, 10632539.43B/s][A[A[A


 92%|█████████▏| 401855488/435779157 [00:31<00:03, 10870748.21B/s][A[A[A


 92%|█████████▏| 402960384/435779157 [00:31<00:03, 10291749.45B/s][A[A[A


 93%|█████████▎| 404105216/435779157 [00:31<00:02, 10599629.50B/s][A[A[A


 93%|█████████▎| 405181440/435779157 [00:31<00:03, 9573064.42B/s] [A[A[A


 93%|█████████▎| 406480896/435779157 [00:32<00:02, 10389900.96B/

In [108]:
model.to(device);

In [109]:
task = args['task_name']

processor = processors[task]()
label_list = processor.get_labels()
num_labels = len(label_list)

In [110]:
def load_and_cache_examples(task, tokenizer, evaluate=False):
    processor = processors[task]()
    output_mode = args['output_mode']
    
    mode = 'dev' if evaluate else 'train'
    cached_features_file = os.path.join(args['data_dir'], f"cached_{mode}_{args['model_name']}_{args['max_seq_length']}_{task}")
    
    if os.path.exists(cached_features_file) and not args['reprocess_input_data']:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
               
    else:
        logger.info("Creating features from dataset file at %s", args['data_dir'])
        label_list = processor.get_labels()
        examples = processor.get_dev_examples(args['data_dir']) if evaluate else processor.get_train_examples(args['data_dir'])
        
        features = convert_examples_to_features(examples, label_list, args['max_seq_length'], tokenizer, output_mode,
            cls_token_at_end=bool(args['model_type'] in ['xlnet']),            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            sep_token=tokenizer.sep_token,
            cls_token_segment_id=2 if args['model_type'] in ['xlnet'] else 0,
            pad_on_left=bool(args['model_type'] in ['xlnet']),                 # pad on the left for xlnet
            pad_token_segment_id=4 if args['model_type'] in ['xlnet'] else 0)
        
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)
        
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    return dataset

In [111]:
def train(train_dataset, model, tokenizer):
    tb_writer = SummaryWriter()
    
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args['train_batch_size'])
    
    t_total = len(train_dataloader) // args['gradient_accumulation_steps'] * args['num_train_epochs']
    
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args['weight_decay']},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args['learning_rate'], eps=args['adam_epsilon'])
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args['warmup_steps'], t_total=t_total)
    
    if args['fp16']:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args['fp16_opt_level'])
        
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args['num_train_epochs'])
    logger.info("  Total train batch size  = %d", args['train_batch_size'])
    logger.info("  Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args['num_train_epochs']), desc="Epoch")
    
    for _ in train_iterator:
        epoch_iterator = tqdm_notebook(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
                      'labels':         batch[3]}
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
            print("\r%f" % loss, end='')

            if args['gradient_accumulation_steps'] > 1:
                loss = loss / args['gradient_accumulation_steps']

            if args['fp16']:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args['max_grad_norm'])
                
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args['max_grad_norm'])

            tr_loss += loss.item()
            if (step + 1) % args['gradient_accumulation_steps'] == 0:
                scheduler.step()  # Update learning rate schedule
                optimizer.step()
                model.zero_grad()
                global_step += 1

                if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
                    # Log metrics
                    if args['evaluate_during_training']:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args['logging_steps'], global_step)
                    logging_loss = tr_loss

                if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args['output_dir'], 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    logger.info("Saving model checkpoint to %s", output_dir)


    return global_step, tr_loss / global_step

In [112]:
from sklearn.metrics import mean_squared_error, matthews_corrcoef, confusion_matrix
from scipy.stats import pearsonr

def get_mismatched(labels, preds):
    mismatched = labels != preds
    examples = processor.get_dev_examples(args['data_dir'])
    wrong = [i for (i, v) in zip(examples, mismatched) if v]
    
    return wrong

def get_eval_report(labels, preds):
    mcc = matthews_corrcoef(labels, preds)
    tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
    return {
        "mcc": mcc,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn
    }, get_mismatched(labels, preds)

def compute_metrics(task_name, preds, labels):
    assert len(preds) == len(labels)
    return get_eval_report(labels, preds)

def evaluate(model, tokenizer, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = args['output_dir']

    results = {}
    EVAL_TASK = args['task_name']

    eval_dataset = load_and_cache_examples(EVAL_TASK, tokenizer, evaluate=True)
    if not os.path.exists(eval_output_dir):
        os.makedirs(eval_output_dir)


    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args['eval_batch_size'])

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args['eval_batch_size'])
    eval_loss = 0.0
    nb_eval_steps = 0
    preds = None
    out_label_ids = None
    for batch in tqdm_notebook(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(device) for t in batch)

        with torch.no_grad():
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
                      'labels':         batch[3]}
            outputs = model(**inputs)
            tmp_eval_loss, logits = outputs[:2]

            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        if preds is None:
            preds = logits.detach().cpu().numpy()
            out_label_ids = inputs['labels'].detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)

    eval_loss = eval_loss / nb_eval_steps
    if args['output_mode'] == "classification":
        preds = np.argmax(preds, axis=1)
    elif args['output_mode'] == "regression":
        preds = np.squeeze(preds)
    result, wrong = compute_metrics(EVAL_TASK, preds, out_label_ids)
    results.update(result)

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return results, wrong

In [113]:
if args['do_train']:
    train_dataset = load_and_cache_examples(task, tokenizer)
    global_step, tr_loss = train(train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

INFO:__main__:Creating features from dataset file at ../datasets/yelp_review_polarity_csv/



  0%|          | 0/560000 [00:00<?, ?it/s][A[A[A


  0%|          | 1/560000 [00:20<3135:17:50, 20.16s/it][A[A[A


  0%|          | 2/560000 [00:20<2199:31:17, 14.14s/it][A[A[A


  0%|          | 1201/560000 [00:20<1536:22:28,  9.90s/it][A[A[A


  0%|          | 1501/560000 [00:20<1074:54:10,  6.93s/it][A[A[A


  0%|          | 2001/560000 [00:20<751:46:21,  4.85s/it] [A[A[A


  0%|          | 2269/560000 [00:20<526:00:25,  3.40s/it][A[A[A


  0%|          | 2531/560000 [00:20<368:03:32,  2.38s/it][A[A[A


  1%|          | 3201/560000 [00:21<257:20:34,  1.66s/it][A[A[A


  1%|          | 3488/560000 [00:21<180:03:52,  1.16s/it][A[A[A


  1%|          | 4101/560000 [00:21<125:54:59,  1.23it/s][A[A[A


  1%|          | 4435/560000 [00:21<88:06:31,  1.75it/s] [A[A[A


  1%|          | 4801/560000 [00:21<61:39:33,  2.50it/s][A[A[A


  1%|          | 5201/5600

  9%|▉         | 51544/560000 [00:37<02:43, 3117.68it/s][A[A[A


  9%|▉         | 52201/560000 [00:37<02:38, 3196.36it/s][A[A[A


  9%|▉         | 52539/560000 [00:37<03:07, 2706.39it/s][A[A[A


  9%|▉         | 52835/560000 [00:37<03:43, 2269.54it/s][A[A[A


 10%|▉         | 53501/560000 [00:37<03:05, 2732.70it/s][A[A[A


 10%|▉         | 53901/560000 [00:37<03:02, 2768.11it/s][A[A[A


 10%|▉         | 54501/560000 [00:38<03:17, 2553.43it/s][A[A[A


 10%|▉         | 55101/560000 [00:38<02:46, 3030.34it/s][A[A[A


 10%|▉         | 55465/560000 [00:38<02:54, 2896.41it/s][A[A[A


 10%|▉         | 55801/560000 [00:38<02:48, 2995.37it/s][A[A[A


 10%|█         | 56401/560000 [00:38<02:36, 3220.71it/s][A[A[A


 10%|█         | 56750/560000 [00:39<04:39, 1797.70it/s][A[A[A


 10%|█         | 57701/560000 [00:39<04:10, 2008.77it/s][A[A[A


 10%|█         | 58401/560000 [00:39<03:19, 2516.25it/s][A[A[A


 10%|█         | 58776/560000 [00:39<04:03, 2059

 18%|█▊        | 102029/560000 [00:55<02:36, 2919.57it/s][A[A[A


 18%|█▊        | 102401/560000 [00:56<02:38, 2879.06it/s][A[A[A


 18%|█▊        | 102801/560000 [00:56<03:12, 2370.19it/s][A[A[A


 18%|█▊        | 103101/560000 [00:56<03:46, 2016.07it/s][A[A[A


 19%|█▊        | 104001/560000 [00:56<03:03, 2486.12it/s][A[A[A


 19%|█▊        | 104325/560000 [00:56<02:55, 2603.71it/s][A[A[A


 19%|█▊        | 104640/560000 [00:56<03:13, 2356.49it/s][A[A[A


 19%|█▊        | 104919/560000 [00:57<03:23, 2236.21it/s][A[A[A


 19%|█▉        | 105174/560000 [00:57<03:26, 2200.54it/s][A[A[A


 19%|█▉        | 105416/560000 [00:57<04:17, 1762.90it/s][A[A[A


 19%|█▉        | 106101/560000 [00:57<03:46, 2003.73it/s][A[A[A


 19%|█▉        | 106334/560000 [00:57<03:49, 1975.07it/s][A[A[A


 19%|█▉        | 107101/560000 [00:57<03:03, 2468.33it/s][A[A[A


 19%|█▉        | 107433/560000 [00:58<03:27, 2181.37it/s][A[A[A


 19%|█▉        | 107901/560000 [00

 27%|██▋       | 150101/560000 [01:16<03:02, 2241.28it/s][A[A[A


 27%|██▋       | 150501/560000 [01:16<02:48, 2433.64it/s][A[A[A


 27%|██▋       | 150787/560000 [01:16<03:07, 2184.99it/s][A[A[A


 27%|██▋       | 151040/560000 [01:17<03:31, 1935.43it/s][A[A[A


 27%|██▋       | 151501/560000 [01:17<03:02, 2241.89it/s][A[A[A


 27%|██▋       | 151764/560000 [01:17<03:04, 2210.21it/s][A[A[A


 27%|██▋       | 152012/560000 [01:17<03:49, 1776.74it/s][A[A[A


 27%|██▋       | 152501/560000 [01:17<03:07, 2173.63it/s][A[A[A


 27%|██▋       | 152901/560000 [01:17<02:59, 2265.75it/s][A[A[A


 27%|██▋       | 153175/560000 [01:17<03:28, 1946.64it/s][A[A[A


 27%|██▋       | 153601/560000 [01:18<02:56, 2300.19it/s][A[A[A


 27%|██▋       | 153886/560000 [01:18<03:05, 2187.14it/s][A[A[A


 28%|██▊       | 154145/560000 [01:18<03:43, 1816.40it/s][A[A[A


 28%|██▊       | 154401/560000 [01:18<03:52, 1743.54it/s][A[A[A


 28%|██▊       | 154603/560000 [01

 35%|███▌      | 197901/560000 [01:36<02:45, 2187.08it/s][A[A[A


 35%|███▌      | 198199/560000 [01:36<02:48, 2142.02it/s][A[A[A


 35%|███▌      | 198469/560000 [01:36<03:14, 1862.70it/s][A[A[A


 36%|███▌      | 198901/560000 [01:36<02:51, 2111.56it/s][A[A[A


 36%|███▌      | 199155/560000 [01:36<03:08, 1910.17it/s][A[A[A


 36%|███▌      | 199401/560000 [01:37<03:22, 1778.48it/s][A[A[A


 36%|███▌      | 199901/560000 [01:37<02:45, 2182.35it/s][A[A[A


 36%|███▌      | 200201/560000 [01:37<03:40, 1634.19it/s][A[A[A


 36%|███▌      | 200429/560000 [01:37<03:39, 1640.85it/s][A[A[A


 36%|███▌      | 201201/560000 [01:37<02:54, 2061.27it/s][A[A[A


 36%|███▌      | 201502/560000 [01:37<02:54, 2059.26it/s][A[A[A


 36%|███▌      | 202101/560000 [01:38<02:28, 2404.22it/s][A[A[A


 36%|███▌      | 202410/560000 [01:38<02:45, 2155.52it/s][A[A[A


 36%|███▌      | 202823/560000 [01:38<02:22, 2511.89it/s][A[A[A


 36%|███▋      | 203201/560000 [01

 44%|████▎     | 244681/560000 [01:54<02:25, 2164.36it/s][A[A[A


 44%|████▍     | 245301/560000 [01:54<01:59, 2635.74it/s][A[A[A


 44%|████▍     | 245638/560000 [01:55<01:55, 2732.93it/s][A[A[A


 44%|████▍     | 246001/560000 [01:55<01:50, 2849.60it/s][A[A[A


 44%|████▍     | 246324/560000 [01:55<02:02, 2553.96it/s][A[A[A


 44%|████▍     | 246612/560000 [01:55<02:09, 2422.50it/s][A[A[A


 44%|████▍     | 246879/560000 [01:55<02:29, 2099.84it/s][A[A[A


 44%|████▍     | 247401/560000 [01:55<02:14, 2332.09it/s][A[A[A


 44%|████▍     | 248001/560000 [01:55<01:51, 2802.95it/s][A[A[A


 44%|████▍     | 248341/560000 [01:55<02:01, 2569.20it/s][A[A[A


 44%|████▍     | 248644/560000 [01:56<02:06, 2468.05it/s][A[A[A


 44%|████▍     | 248924/560000 [01:56<02:40, 1941.21it/s][A[A[A


 44%|████▍     | 249159/560000 [01:56<02:39, 1947.86it/s][A[A[A


 45%|████▍     | 249601/560000 [01:56<02:13, 2317.67it/s][A[A[A


 45%|████▍     | 249901/560000 [01

 53%|█████▎    | 295001/560000 [02:14<01:42, 2573.71it/s][A[A[A


 53%|█████▎    | 295306/560000 [02:14<02:02, 2156.23it/s][A[A[A


 53%|█████▎    | 296001/560000 [02:15<01:39, 2641.00it/s][A[A[A


 53%|█████▎    | 296350/560000 [02:15<01:59, 2202.18it/s][A[A[A


 53%|█████▎    | 296642/560000 [02:15<01:54, 2300.55it/s][A[A[A


 53%|█████▎    | 297101/560000 [02:15<01:54, 2302.43it/s][A[A[A


 53%|█████▎    | 297368/560000 [02:15<01:54, 2300.91it/s][A[A[A


 53%|█████▎    | 297624/560000 [02:15<01:51, 2361.63it/s][A[A[A


 53%|█████▎    | 298101/560000 [02:15<01:34, 2772.36it/s][A[A[A


 53%|█████▎    | 298419/560000 [02:16<01:32, 2835.10it/s][A[A[A


 53%|█████▎    | 298801/560000 [02:16<01:45, 2475.54it/s][A[A[A


 53%|█████▎    | 299201/560000 [02:16<01:41, 2560.99it/s][A[A[A


 53%|█████▎    | 299501/560000 [02:16<01:53, 2299.58it/s][A[A[A


 54%|█████▎    | 300001/560000 [02:16<01:52, 2312.97it/s][A[A[A


 54%|█████▎    | 300501/560000 [02

 61%|██████▏   | 344101/560000 [02:34<01:28, 2437.92it/s][A[A[A


 62%|██████▏   | 344501/560000 [02:34<01:27, 2465.39it/s][A[A[A


 62%|██████▏   | 344801/560000 [02:34<01:26, 2489.95it/s][A[A[A


 62%|██████▏   | 345068/560000 [02:34<01:31, 2346.00it/s][A[A[A


 62%|██████▏   | 345401/560000 [02:34<01:23, 2563.58it/s][A[A[A


 62%|██████▏   | 345901/560000 [02:34<01:12, 2937.90it/s][A[A[A


 62%|██████▏   | 346226/560000 [02:34<01:31, 2334.35it/s][A[A[A


 62%|██████▏   | 346601/560000 [02:35<01:36, 2205.24it/s][A[A[A


 62%|██████▏   | 347201/560000 [02:35<01:29, 2384.25it/s][A[A[A


 62%|██████▏   | 347601/560000 [02:35<01:22, 2570.90it/s][A[A[A


 62%|██████▏   | 347901/560000 [02:35<01:28, 2410.20it/s][A[A[A


 62%|██████▏   | 348201/560000 [02:35<01:34, 2229.56it/s][A[A[A


 62%|██████▏   | 348701/560000 [02:35<01:27, 2424.79it/s][A[A[A


 62%|██████▏   | 348959/560000 [02:35<01:33, 2260.60it/s][A[A[A


 62%|██████▏   | 349701/560000 [02

 70%|███████   | 394301/560000 [02:55<01:06, 2510.30it/s][A[A[A


 70%|███████   | 394601/560000 [02:55<01:04, 2557.14it/s][A[A[A


 71%|███████   | 394877/560000 [02:55<01:03, 2587.40it/s][A[A[A


 71%|███████   | 395150/560000 [02:55<01:16, 2155.21it/s][A[A[A


 71%|███████   | 395701/560000 [02:55<01:18, 2085.23it/s][A[A[A


 71%|███████   | 395926/560000 [02:56<01:25, 1916.21it/s][A[A[A


 71%|███████   | 396801/560000 [02:56<01:09, 2332.41it/s][A[A[A


 71%|███████   | 397087/560000 [02:56<01:07, 2411.50it/s][A[A[A


 71%|███████   | 397401/560000 [02:56<01:29, 1818.37it/s][A[A[A


 71%|███████   | 398101/560000 [02:56<01:13, 2193.94it/s][A[A[A


 71%|███████   | 398501/560000 [02:56<01:10, 2288.96it/s][A[A[A


 71%|███████   | 398801/560000 [02:57<01:09, 2317.57it/s][A[A[A


 71%|███████▏  | 399201/560000 [02:57<01:01, 2630.11it/s][A[A[A


 71%|███████▏  | 399503/560000 [02:57<01:15, 2138.62it/s][A[A[A


 71%|███████▏  | 399901/560000 [02

 79%|███████▉  | 444463/560000 [03:15<00:48, 2377.70it/s][A[A[A


 79%|███████▉  | 445001/560000 [03:15<00:51, 2251.15it/s][A[A[A


 80%|███████▉  | 445401/560000 [03:15<00:44, 2555.14it/s][A[A[A


 80%|███████▉  | 445901/560000 [03:15<00:38, 2962.13it/s][A[A[A


 80%|███████▉  | 446260/560000 [03:15<00:46, 2430.03it/s][A[A[A


 80%|███████▉  | 446601/560000 [03:15<00:44, 2573.75it/s][A[A[A


 80%|███████▉  | 446903/560000 [03:16<00:45, 2473.73it/s][A[A[A


 80%|███████▉  | 447182/560000 [03:16<00:57, 1961.57it/s][A[A[A


 80%|████████  | 448001/560000 [03:16<00:44, 2495.66it/s][A[A[A


 80%|████████  | 448371/560000 [03:16<00:47, 2336.03it/s][A[A[A


 80%|████████  | 448691/560000 [03:16<00:53, 2080.78it/s][A[A[A


 80%|████████  | 449301/560000 [03:16<00:46, 2400.32it/s][A[A[A


 80%|████████  | 449605/560000 [03:17<00:52, 2117.55it/s][A[A[A


 80%|████████  | 450101/560000 [03:17<00:44, 2454.07it/s][A[A[A


 80%|████████  | 450402/560000 [03

 88%|████████▊ | 491101/560000 [03:32<00:26, 2570.88it/s][A[A[A


 88%|████████▊ | 491372/560000 [03:33<00:35, 1959.94it/s][A[A[A


 88%|████████▊ | 491801/560000 [03:33<00:29, 2325.36it/s][A[A[A


 88%|████████▊ | 492201/560000 [03:33<00:26, 2539.48it/s][A[A[A


 88%|████████▊ | 492496/560000 [03:33<00:26, 2593.43it/s][A[A[A


 88%|████████▊ | 492901/560000 [03:33<00:24, 2703.55it/s][A[A[A


 88%|████████▊ | 493201/560000 [03:33<00:28, 2378.16it/s][A[A[A


 88%|████████▊ | 493801/560000 [03:33<00:25, 2613.63it/s][A[A[A


 88%|████████▊ | 494085/560000 [03:33<00:24, 2645.92it/s][A[A[A


 88%|████████▊ | 494401/560000 [03:34<00:23, 2754.58it/s][A[A[A


 88%|████████▊ | 494689/560000 [03:34<00:25, 2559.21it/s][A[A[A


 88%|████████▊ | 495101/560000 [03:34<00:23, 2778.23it/s][A[A[A


 88%|████████▊ | 495401/560000 [03:34<00:23, 2793.15it/s][A[A[A


 89%|████████▊ | 495690/560000 [03:34<00:23, 2787.34it/s][A[A[A


 89%|████████▊ | 495976/560000 [03

 96%|█████████▌| 538301/560000 [03:52<00:17, 1210.93it/s][A[A[A


 96%|█████████▋| 539601/560000 [03:52<00:12, 1659.21it/s][A[A[A


 96%|█████████▋| 540111/560000 [03:52<00:11, 1770.71it/s][A[A[A


 97%|█████████▋| 540531/560000 [03:52<00:09, 2050.93it/s][A[A[A


 97%|█████████▋| 540923/560000 [03:52<00:08, 2295.80it/s][A[A[A


 97%|█████████▋| 541293/560000 [03:52<00:07, 2537.81it/s][A[A[A


 97%|█████████▋| 541653/560000 [03:53<00:07, 2357.41it/s][A[A[A


 97%|█████████▋| 542001/560000 [03:53<00:09, 1975.13it/s][A[A[A


 97%|█████████▋| 542701/560000 [03:53<00:07, 2405.30it/s][A[A[A


 97%|█████████▋| 543031/560000 [03:53<00:07, 2339.24it/s][A[A[A


 97%|█████████▋| 543401/560000 [03:53<00:06, 2523.47it/s][A[A[A


 97%|█████████▋| 543901/560000 [03:53<00:06, 2639.68it/s][A[A[A


 97%|█████████▋| 544201/560000 [03:54<00:05, 2649.77it/s][A[A[A


 97%|█████████▋| 544501/560000 [03:54<00:05, 2649.86it/s][A[A[A


 97%|█████████▋| 544784/560000 [03

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=70000.0, style=ProgressStyle(description_…

0.834361



0.135294

KeyboardInterrupt: 

In [None]:
if args['do_train']:
    if not os.path.exists(args['output_dir']):
            os.makedirs(args['output_dir'])
    logger.info("Saving model checkpoint to %s", args['output_dir'])
    
    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
    model_to_save.save_pretrained(args['output_dir'])
    tokenizer.save_pretrained(args['output_dir'])
    torch.save(args, os.path.join(args['output_dir'], 'training_args.bin'))

In [None]:
results = {}
if args['do_eval']:
    checkpoints = [args['output_dir']]
    if args['eval_all_checkpoints']:
        checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + '/**/' + WEIGHTS_NAME, recursive=True)))
        logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
    logger.info("Evaluate the following checkpoints: %s", checkpoints)
    for checkpoint in checkpoints:
        global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
        model = model_class.from_pretrained(checkpoint)
        model.to(device)
        result, wrong_preds = evaluate(model, tokenizer, prefix=global_step)
        result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
        results.update(result)

In [None]:
results