In [128]:
import os
import glob
import pandas as pd
import pickle
from collections import defaultdict
import logging
import pickle
from tqdm import tqdm
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, BertConfig, BertTokenizer, AdamW, WEIGHTS_NAME

from tensorflow.python.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup
#from tensorboardX import SummaryWriter
import numpy as np
import time
import datetime
import random
import argparse
import scipy
import sklearn
import math
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score, confusion_matrix

logger = logging.getLogger(__name__)

CUDA = (torch.cuda.device_count() > 0)
CLASSES = ['agrees','neutral','disagrees'] # correspond to 0, 1, 2 in labeled data
NUM_LABELS = 3

def simple_accuracy(preds, labels):
    return (preds == labels).mean()

def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    micro_f1 = f1_score(y_true=labels, y_pred=preds, average='micro')
    macro_f1 = f1_score(y_true=labels, y_pred=preds, average='macro')
    return {
        "acc": acc,
        "micro_f1": micro_f1,
        "macro_f1":macro_f1,
        "acc_and_macro_f1": (acc + macro_f1) / 2,
    }

def cm(preds, labels):
    return confusion_matrix(preds, labels)

def get_pred_label(res_,to_str=False):
    if to_str:
        return CLASSES[res_.index(max(res_))]
    else:
        return res_.index(max(res_))


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))

    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

def save_pretrained(model, save_directory):
    """ Save a model and its configuration file to a directory, so that it
        can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
    """
    assert os.path.isdir(
        save_directory
    ), "Saving path should be a directory where the model and configuration can be saved"

    # Only save the model itself if we are using distributed training
    model_to_save = model.module if hasattr(model, "module") else model

    # Attach architecture to the config
    model_to_save.config.architectures = [model_to_save.__class__.__name__]

    # Save configuration file
    model_to_save.config.save_pretrained(save_directory)
    # with open(os.path.join(save_directory,'config.json'), 'w') as outfile:
    #     json.dump(model_to_save.config, outfile)

    # If we save using the predefined names, we can load using `from_pretrained`
    output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
    torch.save(model_to_save.state_dict(), output_model_file)
    logger.info("Model weights saved in {}".format(output_model_file))


def build_dataloader(*args, sampler='random'):
    #print(args[:2])
    data = (torch.tensor(x) for x in args)
    #print(data[0])
    data = TensorDataset(*data)
    
    sampler = RandomSampler(data) if sampler == 'random' else SequentialSampler(data)
    dataloader = DataLoader(data, sampler=sampler, batch_size=1)

    return dataloader

def get_out_data(dat_path,max_seq_length=500):
    #eval_set = 'train' # can also be 'test'
    data = pd.read_csv(dat_path,
                              sep='\t',header=None)
    data.columns = ['text','label']#,'outlet']

    out = defaultdict(list)

    print('Number of examples:',len(data))
    to_predict = data.text.values
    true = data.label.values

    for dat_ix in range(0,len(data)):
        sent = to_predict[dat_ix]
        #print(sent)
        label = true[dat_ix]
        encoded_sent = tokenizer.encode(sent,add_special_tokens=True)[1:] # remove [CLS] auto-inserted at beginning
        #print('encoded sent:',encoded_sent)
        CLS_ix = encoded_sent.index(101)
        SEP_ix = encoded_sent[CLS_ix:].index(102)+CLS_ix
        out['input_ids'].append(encoded_sent)
        out['sentences'].append(sent)
        out['label'].append(label)
        out['index_CLS'].append(CLS_ix)
        out['index_SEP_after_CLS'].append(SEP_ix)
        #print(encoded_sent[CLS_ix])

    out['input_ids'] = pad_sequences(
            out['input_ids'],
            maxlen=max_seq_length,
            dtype="long",
            value=0,
            truncating="post",
            padding="post")


    print('Adding attention masks...')
    # get attn masks
    for sent_no,sent in enumerate(out['input_ids']):
        tok_type_ids = [0 for tok_id in sent]
        #print('tok type ids:',tok_type_ids
        #     )
        #mask = [int(tok_id > 0) for tok_id in sent]
        #print('old mask:',mask)
        #print('CLS index:',out['index_CLS'][sent_no])
        #print('SEP index:',out['index_SEP_after_CLS'][sent_no])
        mask = [0 if n < out['index_CLS'][sent_no] or n > out['index_SEP_after_CLS'][sent_no] else 1 
                for n,tok_id in enumerate(sent)]
        out['attention_mask'].append(mask)
        out['token_type_ids'].append(tok_type_ids)
    #print(len(out['labels']))
    #print(sum(out['labels']))

    print('Preparing input examples for prediction...')

    return out

In [129]:
OUTPUT_DIR = './'
NUM_EPOCHS = 1
LEARNING_RATE = 2e-5
SEED = 420
DATA_NAME = 'mturk_windowed_1'
DATA_DIR = os.path.join('../data_creation/scripts/save',DATA_NAME)
BASE_MODEL = 'uncased_LM_cc_output'
MODELS_DIR = '../BERT/LM_finetuned'
MODEL_NAME_OR_PATH = os.path.join(MODELS_DIR,BASE_MODEL)
print('model name/path:',MODEL_NAME_OR_PATH)
EVAL_ON_TEST = False
PRED_FILE_NAME = 'dev_preds'
DO_TRAIN = True
DO_EVAL = True
LOCAL_RANK = -1
NO_CUDA = False
print('data dir:',DATA_DIR)

model name/path: ../BERT/LM_finetuned/uncased_LM_cc_output
data dir: ../data_creation/scripts/save/mturk_windowed_1


In [145]:
if (
        os.path.exists(OUTPUT_DIR)
        and os.listdir(OUTPUT_DIR)
        and DO_TRAIN
        and not OUTPUT_DIR
):
    raise ValueError(
        "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
            OUTPUT_DIR
        )
    )

# Setup CUDA, GPU & distributed training
if LOCAL_RANK == -1 or NO_CUDA:
    device = torch.device("cuda" if torch.cuda.is_available() and not NO_CUDA else "cpu")
    N_GPU = torch.cuda.device_count()
else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(LOCAL_RANK)
    device = torch.device("cuda", LOCAL_RANK)
    torch.distributed.init_process_group(backend="nccl")
    N_GPU = 1
DEVICE = device

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO if LOCAL_RANK in [-1, 0] else logging.WARN,
)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Load model
# config = BertConfig.from_pretrained(MODEL_NAME_OR_PATH, num_labels=NUM_LABELS)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
# model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME_OR_PATH,
#                                                           config=config)
config = BertConfig.from_pretrained('bert-base-uncased', num_labels=3)
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=3,
    output_attentions=True,
    output_hidden_states=False,
)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)#ARGS.casing)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-8)
model.to(DEVICE)

04/15/2020 13:43:55 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /Users/yiweiluo/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.8f56353af4a709bf5ff0fbc915d8f5b42bfff892cbb6ac98c3c45f481a03c685
04/15/2020 13:43:55 - INFO - transformers.configuration_utils -   Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "do_sample": false,
  "eos_token_ids": 0,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddi

Loading config...
Config path: bert-base-uncased


04/15/2020 13:43:55 - INFO - transformers.modeling_utils -   loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin from cache at /Users/yiweiluo/.cache/torch/transformers/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157
04/15/2020 13:43:58 - INFO - transformers.modeling_utils -   Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
04/15/2020 13:43:58 - INFO - transformers.modeling_utils -   Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
04/15/2020 13:43:58 - INFO - transfor

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [12]:
v_ix, v = [],[]
with open('vocab.txt-vocab.txt','r') as f:
    for ix,line in enumerate(f):
        v_ix.append(ix)
        v.append(line.strip())
vocab_dict = dict(zip(v_ix,v))

In [131]:
# Load data for prediction/eval
eval_set = 'test' if EVAL_ON_TEST else 'dev' # can also be 'test'
eval_dat_path = os.path.join(DATA_DIR,eval_set+'.tsv')
eval_data = get_out_data(eval_dat_path,max_seq_length=500)

test_inputs, test_labels, test_masks = eval_data['input_ids'], eval_data['label'], eval_data['attention_mask']
test_dataloader = build_dataloader(
    test_inputs, test_labels, test_masks,
    sampler='order')

Number of examples: 130
Adding attention masks...
Preparing input examples for prediction...


In [134]:
test_inputs[0]

array([  102,  2037,  2817,  1010,  2405,  6928,  1999,  1996,  3485,
        3267,  4785,  2689,  1010,  2758,  2008,  2087,  2436,  3121,
        2275,  7715,  2241,  2006,  1037,  5109,  1011,  2214,  5675,
        2008,  3594,  1996, 21453,  6165,  1997,  2273,  1012,   102,
         101,  3121,  2323,  2024,  8566,  3401,  5907,  1011,  5860,
       20026, 19185, 13827,  1999,  9829,  7216,  2050,  2138,  4292,
        7715,  2012,  3621, 16676,  3798,  2064,  2393,  4337,  3795,
       12959,  1012,   102,  1523,  1999,  1037,  2843,  1997,  3121,
        1010,  2017,  2156,  2943,  8381,  2003,  1037,  2843,  3020,
        2138,  1996,  3115,  2003, 10250, 12322,  9250,  2005,  2273,
        1521,  1055,  2303,  3684,  2537,  1010,  1524,  2056, 11235,
        2332,  2863,  1010,  1037,  2522,  1011,  3166,  1997,  1996,
        2817,  1998,  1037, 16012, 21281, 19570,  2923,  2012,  5003,
       14083, 13149,  2102,  2118,  2966,  2415,  1999,  1996,  4549,
        1012,   102,

In [106]:
eval_dat_path

'../data_creation/scripts/save/mturk_windowed_1/dev.tsv'

In [107]:
tr = pd.read_csv('../data_creation/scripts/save/mturk_windowed_1/train.tsv',sep='\t',header=None)

In [119]:
tr.shape

(1495, 2)

In [127]:
from collections import Counter
Counter(train_labels)

Counter({0: 556, 1: 618, 2: 321})

In [132]:
# Prepare training data
if DO_TRAIN:
    if not os.path.exists(OUTPUT_DIR) and LOCAL_RANK in [-1, 0]:
        os.makedirs(OUTPUT_DIR)

    logger.info("Saving model checkpoint to %s", OUTPUT_DIR)
    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    model_to_save = (
        model.module if hasattr(model, "module") else model
    )  # Take care of distributed/parallel training
    save_pretrained(model_to_save,OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Good practice: save your training arguments together with the trained model
    #torch.save(ARGS, os.path.join(ARGS.output_dir, "training_args.bin"))

#     if os.path.exists(OUTPUT_DIR + "/data.cache.pkl"):
#         data = pickle.load(open(OUTPUT_DIR + "/data.cache.pkl", 'rb'))
#     else:
    data = get_out_data(os.path.join(DATA_DIR, 'train.tsv'))
    pickle.dump(data, open(OUTPUT_DIR + "/data.cache.pkl", 'wb'))

    train_inputs, train_labels, train_masks = data['input_ids'], data['label'], data['attention_mask']
    train_dataloader = build_dataloader(
        train_inputs, train_labels, train_masks)

    total_steps = len(train_dataloader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps)

04/15/2020 13:26:58 - INFO - __main__ -   Saving model checkpoint to ./
04/15/2020 13:26:58 - INFO - transformers.configuration_utils -   Configuration saved in ./config.json
04/15/2020 13:26:59 - INFO - __main__ -   Model weights saved in ./pytorch_model.bin


Number of examples: 1495
Adding attention masks...
Preparing input examples for prediction...


In [143]:
for step, batch in enumerate(train_dataloader):
    input_ids, labels, masks = batch
    print(step,labels)
#     outputs = model(
#         input_ids,
#         attention_mask=masks,
#         labels=labels)

0 tensor([2])
1 tensor([0])
2 tensor([0])
3 tensor([1])
4 tensor([0])
5 tensor([1])
6 tensor([1])
7 tensor([2])
8 tensor([0])
9 tensor([0])
10 tensor([2])
11 tensor([1])
12 tensor([1])
13 tensor([1])
14 tensor([2])
15 tensor([2])
16 tensor([1])
17 tensor([1])
18 tensor([2])
19 tensor([1])
20 tensor([1])
21 tensor([0])
22 tensor([0])
23 tensor([0])
24 tensor([1])
25 tensor([2])
26 tensor([0])
27 tensor([0])
28 tensor([1])
29 tensor([1])
30 tensor([2])
31 tensor([2])
32 tensor([1])
33 tensor([0])
34 tensor([1])
35 tensor([1])
36 tensor([2])
37 tensor([0])
38 tensor([1])
39 tensor([1])
40 tensor([1])
41 tensor([0])
42 tensor([1])
43 tensor([2])
44 tensor([0])
45 tensor([1])
46 tensor([1])
47 tensor([1])
48 tensor([1])
49 tensor([1])
50 tensor([0])
51 tensor([1])
52 tensor([2])
53 tensor([2])
54 tensor([1])
55 tensor([1])
56 tensor([2])
57 tensor([1])
58 tensor([1])
59 tensor([0])
60 tensor([1])
61 tensor([0])
62 tensor([1])
63 tensor([0])
64 tensor([2])
65 tensor([1])
66 tensor([1])
67 te

742 tensor([1])
743 tensor([1])
744 tensor([1])
745 tensor([2])
746 tensor([1])
747 tensor([1])
748 tensor([1])
749 tensor([2])
750 tensor([2])
751 tensor([2])
752 tensor([2])
753 tensor([1])
754 tensor([1])
755 tensor([0])
756 tensor([2])
757 tensor([1])
758 tensor([0])
759 tensor([2])
760 tensor([0])
761 tensor([1])
762 tensor([0])
763 tensor([1])
764 tensor([0])
765 tensor([0])
766 tensor([1])
767 tensor([0])
768 tensor([1])
769 tensor([1])
770 tensor([2])
771 tensor([1])
772 tensor([0])
773 tensor([1])
774 tensor([1])
775 tensor([2])
776 tensor([2])
777 tensor([2])
778 tensor([0])
779 tensor([2])
780 tensor([0])
781 tensor([0])
782 tensor([2])
783 tensor([1])
784 tensor([0])
785 tensor([1])
786 tensor([1])
787 tensor([0])
788 tensor([1])
789 tensor([0])
790 tensor([0])
791 tensor([2])
792 tensor([2])
793 tensor([0])
794 tensor([1])
795 tensor([0])
796 tensor([0])
797 tensor([0])
798 tensor([0])
799 tensor([0])
800 tensor([1])
801 tensor([0])
802 tensor([2])
803 tensor([2])
804 tens

1418 tensor([2])
1419 tensor([0])
1420 tensor([2])
1421 tensor([0])
1422 tensor([1])
1423 tensor([0])
1424 tensor([0])
1425 tensor([1])
1426 tensor([0])
1427 tensor([0])
1428 tensor([2])
1429 tensor([0])
1430 tensor([0])
1431 tensor([2])
1432 tensor([2])
1433 tensor([1])
1434 tensor([2])
1435 tensor([1])
1436 tensor([1])
1437 tensor([1])
1438 tensor([0])
1439 tensor([0])
1440 tensor([0])
1441 tensor([1])
1442 tensor([1])
1443 tensor([0])
1444 tensor([1])
1445 tensor([1])
1446 tensor([2])
1447 tensor([1])
1448 tensor([0])
1449 tensor([1])
1450 tensor([0])
1451 tensor([2])
1452 tensor([2])
1453 tensor([0])
1454 tensor([1])
1455 tensor([0])
1456 tensor([1])
1457 tensor([1])
1458 tensor([1])
1459 tensor([1])
1460 tensor([0])
1461 tensor([2])
1462 tensor([0])
1463 tensor([0])
1464 tensor([1])
1465 tensor([0])
1466 tensor([1])
1467 tensor([1])
1468 tensor([0])
1469 tensor([1])
1470 tensor([0])
1471 tensor([2])
1472 tensor([1])
1473 tensor([2])
1474 tensor([2])
1475 tensor([0])
1476 tensor([2

In [136]:
input_ids,labels,masks

(tensor([[  102,  1996,  3857,  6279,  1997,  2522,  2475,  3310,  2013,  5255,
          10725, 20145,  1010,  2004,  2092,  2004,  2013,  2455,  1011,  2224,
           3431,  1012,   102,   101,  2122,  2047,  3463,  2097,  2393,  2149,
           3437,  2590,  3980,  1999,  3408,  1997,  2119,  2712,  1011,  2504,
           4125,  1998,  2129,  1996,  4774,  1005,  1055,  3147,  4655,  2024,
          14120,  2000,  3795,  2689,  1012,   102, 17512,  3137, 21193,  2031,
           2468,  2426,  1996,  2087,  4069, 18407,  1997,  3795, 12959,  1012,
            102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,   

In [137]:
outputs

(tensor(0.6573, grad_fn=<MultiMarginLossBackward>),
 tensor([[-0.3640, -0.1409,  0.0541]], grad_fn=<AddmmBackward>))

In [138]:
loss,logits = outputs

In [139]:
#labels = labels.cpu().numpy()
#input_ids = input_ids.cpu().numpy()
preds = scipy.special.softmax(logits.detach().numpy(), axis=1)
input_toks = [
    tokenizer.convert_ids_to_tokens(s) for s in input_ids
]

for seq, label, pred in zip(input_toks, labels, preds):
    sep_char = '+' if np.argmax(pred) == label else '-'
    print(sep_char * 40 + '\n')
    print(' '.join(seq) + '\n')
    print('label: ' + str(label) + '\n')
    print('pred: ' + str(np.argmax(pred)) + '\n')
    print('dist: ' + str(pred) + '\n')
    print('\n\n')

----------------------------------------

[SEP] the build ##up of co ##2 comes from burning fossil fuels , as well as from land - use changes . [SEP] [CLS] these new results will help us answer important questions in terms of both sea - level rise and how the planet ' s cold regions are responding to global change . [SEP] retreating mountain glaciers have become among the most prominent icons of global warming . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [None]:
for epoch_i in range(0, ARGS.num_epochs):

        # ========================================
        #               Training
        # ========================================
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, ARGS.num_epochs))
        print('Training...')

        losses = []
        t0 = time.time()
        model.train()
        for step, batch in enumerate(train_dataloader):
            #print(step,batch)

            if step % 40 == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}. Loss: {:.2f}'.format(
                    step, len(train_dataloader), elapsed, float(np.mean(losses))))

            if CUDA:
                batch = (x.cuda() for x in batch)
            input_ids, labels, masks = batch
            model.zero_grad()

            outputs = model(
                input_ids,
                attention_mask=masks,
                labels=labels)

            #print(len(outputs))

            #loss, _, _ = outputs
            loss, _ = outputs
            losses.append(loss.item())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

        avg_loss = np.mean(losses)
        #writer.add_scalar('train/loss', np.mean(avg_loss), epoch_i)

        print("")
        print("  Average training loss: {0:.2f}".format(avg_loss))
        print("  Training epcoh took: {:}".format(format_time(time.time() - t0)))

# ========================================
#               Validation
# ========================================
print("")
print("Running Validation...")

t0 = time.time()
model.eval()
losses = []
all_preds = []
all_labels = []
log = open(ARGS.output_dir + '/log', 'w')
for step, batch in enumerate(test_dataloader):

    if CUDA:
        batch = (x.cuda() for x in batch)
    input_ids, labels, masks = batch

    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=masks,
            labels=labels)
    #loss, logits, attns = outputs
    loss, logits = outputs

    losses.append(loss.item())

    labels = labels.cpu().numpy()
    input_ids = input_ids.cpu().numpy()
    preds = scipy.special.softmax(logits.cpu().numpy(), axis=1)
    input_toks = [
        tokenizer.convert_ids_to_tokens(s) for s in input_ids
    ]

    for seq, label, pred in zip(input_toks, labels, preds):
        sep_char = '+' if np.argmax(pred) == label else '-'
        log.write(sep_char * 40 + '\n')
        log.write(' '.join(seq) + '\n')
        log.write('label: ' + str(label) + '\n')
        log.write('pred: ' + str(np.argmax(pred)) + '\n')
        log.write('dist: ' + str(pred) + '\n')
        log.write('\n\n')

        all_preds += [pred]
        all_labels += [label]
log.close()
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

avg_loss = np.mean(losses)
f1 = sklearn.metrics.f1_score(all_labels, np.argmax(all_preds, axis=1),average='macro')
acc = sklearn.metrics.accuracy_score(all_labels, np.argmax(all_preds, axis=1))
#auc = sklearn.metrics.roc_auc_score(all_labels, all_preds[:, 1])

#writer.add_scalar('eval/acc', acc, epoch_i)
#writer.add_scalar('eval/auc', auc, epoch_i)
#writer.add_scalar('eval/f1', f1, epoch_i)
#writer.add_scalar('eval/loss', f1, epoch_i)

print("  Loss: {0:.2f}".format(avg_loss))
print("  Accuracy: {0:.2f}".format(acc))
print("  F1: {0:.2f}".format(f1))
#print("  AUC: {0:.2f}".format(auc))
print("  Validation took: {:}".format(format_time(time.time() - t0)))

# Want to save: model, config, vocab, eval results, preds, training_args,
result = {'acc': acc,
    'f1':f1,
    'cm':cm(np.argmax(all_preds, axis=1),all_labels)}

preds_df = pd.DataFrame({'true':all_labels,
'predicted':np.argmax(all_preds, axis=1)})
preds_df.to_csv(ARGS.output_dir+'/{}.tsv'.format(ARGS.pred_file_name),sep='\t',index=False)

output_eval_file = os.path.join(ARGS.output_dir, "eval_results_{}.txt".format(ARGS.pred_file_name))
with open(output_eval_file, "w") as writer:
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
