## BaseLine for bert-project

### Procedure
1. convert the `dev.json` to four sentences: (question, correct answer), (question, wrong answer). 
For instance, original sentence is **"Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because `_____` has more resistance. (A) marble floor (B) wet floor"** where correct answer is wet floor. We can convert this sentenct into: (**"Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because `_____` has more resistance."**, **"wet floor"**), (**"Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because `_____` has more resistance."**, **"marble floor"**)
2. using bert to predict the next sentence. In this case, we use correct and wrong answer as the next sentence to get confidence separately, pick up the higer one as the predicted answer.
3. accuracy on dev is `55.4 %`, on test is `53.3 %` using pre-trained model `bert-base-uncased`
   accuracy on dev is `54.3 %`, on test is `48.7 %` using fine-tuned model `finetuned_lm`

In [81]:
import json
import pprint
import collections
import re
import pprint

from os.path import abspath, dirname, join

strip_space = lambda s: re.sub("\s+", " ", s.strip().replace("\r", "").replace("\n", ""))
BASE_PATH = "/Users/shenweihai/Desktop/Projects/Ubuntu_exx/tutorial/bert-proj"

def load_examples(file_name: str):
    ans = []
    err = collections.defaultdict(int)
    with open(file_name, 'r') as handler:
        for line in handler:
            example = json.loads(line)
            question = example['question']
            example = json.loads(line)
            question = example['question']
            
            idx = question.index("(A)")
            Q, second = strip_space(question[0:idx]), question[idx:]
                
            ans_idx = second.index("(B)")
            ans_first, ans_second = strip_space(second[0:ans_idx].replace("(A)", "")), strip_space(second[ans_idx:].replace("(B)", ""))
            ans_first = re.sub("or\s*$|\.\s*$", "", ans_first)
            ans_first = ans_first.lstrip().rstrip()
            ans_second = re.sub("or\s*$|\.\s*$", "", ans_second)
            ans_second = ans_second.lstrip().rstrip()
            
            if example.get("answer_index", 0) == 1:
                ans_first, ans_second = ans_second, ans_first
            
            ans.append((Q, ans_first, ans_second))
    return ans

# ans = load_examples(join(BASE_PATH, "quarel-data", "quarel-v1-dev.json"))
ans = load_examples(join(BASE_PATH, "quarel-data", "quarel-v1-test.json"))
pprint.pprint(ans[0:3])

[('Does the Sun appear dimmer on Earth or Pluto?', 'Pluto', 'Earth'),
 ('Tank the kitten learned from trial and error that carpet is rougher then '
  'skin. When he scratches his claws over carpet it generates _____ then when '
  'he scratches his claws over skin',
  'more heat',
  'less heat'),
 ('An airplane takes longer to reach takeoff speed on a dirt runway than on a '
  'paved runway. This is because the dirt runway has',
  'more resistance',
  'less resistance')]


In [82]:
# get corresponding text and segments_ids
import json
from pytorch_pretrained_bert import BertTokenizer
model_name = 'bert-base-uncased'
model_name = '/Users/shenweihai/Desktop/Projects/Ubuntu_exx/tutorial/bert-proj/finetuned_lm'
tokenizer = BertTokenizer.from_pretrained(model_name)

bert_ans = []
for Q, correct, wrong in ans:
    first_sentence = tokenizer.tokenize("[CLS] {0} [SEP]".format(Q)) 
    
    tokenized_text = tokenizer.tokenize("[CLS] {0} [SEP] {1} [SEP]".format(Q, correct))
    seg_ids = [1] * len(tokenized_text)
    for idx in range(len(first_sentence)):
        seg_ids[idx] = 0
    
    tokenized_text2 = tokenizer.tokenize("[CLS] {0} [SEP] {1} [SEP]".format(Q, wrong))
    seg_ids2 = [1] * len(tokenized_text2)
    for idx in range(len(first_sentence)):
        seg_ids2[idx] = 0
        
    bert_ans.append(["[CLS] {0} [SEP] {1} [SEP]".format(Q, correct), json.dumps(seg_ids), "[CLS] {0} [SEP] {1} [SEP]".format(Q, wrong), json.dumps(seg_ids2)])

pprint.pprint(bert_ans[0:3])


[['[CLS] Does the Sun appear dimmer on Earth or Pluto? [SEP] Pluto [SEP]',
  '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]',
  '[CLS] Does the Sun appear dimmer on Earth or Pluto? [SEP] Earth [SEP]',
  '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]'],
 ['[CLS] Tank the kitten learned from trial and error that carpet is rougher '
  'then skin. When he scratches his claws over carpet it generates _____ then '
  'when he scratches his claws over skin [SEP] more heat [SEP]',
  '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, '
  '0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]',
  '[CLS] Tank the kitten learned from trial and error that carpet is rougher '
  'then skin. When he scratches his claws over carpet it generates _____ then '
  'when he scratches his claws over skin [SEP] less heat [SEP]',
  '[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, '
  '0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]'],
 ['[CLS] An airpla

In [83]:
# try the predict next sentence model
# https://github.com/huggingface/pytorch-pretrained-BERT/issues/48
import json
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM,BertForNextSentencePrediction,BertForSequenceClassification
import torch.nn.functional as F

model_name = 'bert-base-uncased'
model_name = '/Users/shenweihai/Desktop/Projects/Ubuntu_exx/tutorial/bert-proj/finetuned_lm'
tokenizer = BertTokenizer.from_pretrained(model_name)

# Load pre-trained model (weights)
model = BertForNextSentencePrediction.from_pretrained(model_name)
model.eval()

total, total_suc = 0, 0
# Tokenized input
for item in bert_ans:
    text, segs, text_err, segs_err =  item

    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = json.loads(segs)

    tokenized_text_err = tokenizer.tokenize(text_err)
    indexed_tokens_err = tokenizer.convert_tokens_to_ids(tokenized_text_err)
    segments_ids_err = json.loads(segs_err)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    tokens_tensor_err = torch.tensor([indexed_tokens_err])
    segments_tensors_err = torch.tensor([segments_ids_err])

    # # Load pre-trained model (weights)
    # model = BertForNextSentencePrediction.from_pretrained(model_name)
    # model.eval()

    # Predict is Next Sentence ?
    predictions = model(tokens_tensor, segments_tensors )
    confidence = float(F.softmax(predictions)[0][0])

    predictions_err = model(tokens_tensor_err, segments_tensors_err)
    confidence_err = float(F.softmax(predictions_err)[0][0])

    total += 1
    if confidence > confidence_err:
        total_suc += 1
    print("current acc: %s, total_suc: %s, total: %s" % (total_suc / (total + 0.0), total_suc, total))
    # print(confidence, confidence_err)




current acc: 1.0, total_suc: 1, total: 1
current acc: 1.0, total_suc: 2, total: 2
current acc: 0.6666666666666666, total_suc: 2, total: 3
current acc: 0.5, total_suc: 2, total: 4
current acc: 0.4, total_suc: 2, total: 5
current acc: 0.3333333333333333, total_suc: 2, total: 6
current acc: 0.42857142857142855, total_suc: 3, total: 7
current acc: 0.5, total_suc: 4, total: 8
current acc: 0.4444444444444444, total_suc: 4, total: 9
current acc: 0.4, total_suc: 4, total: 10
current acc: 0.36363636363636365, total_suc: 4, total: 11
current acc: 0.4166666666666667, total_suc: 5, total: 12
current acc: 0.46153846153846156, total_suc: 6, total: 13
current acc: 0.42857142857142855, total_suc: 6, total: 14
current acc: 0.4, total_suc: 6, total: 15
current acc: 0.4375, total_suc: 7, total: 16
current acc: 0.47058823529411764, total_suc: 8, total: 17
current acc: 0.5, total_suc: 9, total: 18
current acc: 0.5263157894736842, total_suc: 10, total: 19
current acc: 0.5, total_suc: 10, total: 20
current a

current acc: 0.5337837837837838, total_suc: 79, total: 148
current acc: 0.5369127516778524, total_suc: 80, total: 149
current acc: 0.54, total_suc: 81, total: 150
current acc: 0.5364238410596026, total_suc: 81, total: 151
current acc: 0.5328947368421053, total_suc: 81, total: 152
current acc: 0.5359477124183006, total_suc: 82, total: 153
current acc: 0.5324675324675324, total_suc: 82, total: 154
current acc: 0.5290322580645161, total_suc: 82, total: 155
current acc: 0.532051282051282, total_suc: 83, total: 156
current acc: 0.535031847133758, total_suc: 84, total: 157
current acc: 0.5316455696202531, total_suc: 84, total: 158
current acc: 0.5345911949685535, total_suc: 85, total: 159
current acc: 0.5375, total_suc: 86, total: 160
current acc: 0.5341614906832298, total_suc: 86, total: 161
current acc: 0.5308641975308642, total_suc: 86, total: 162
current acc: 0.5276073619631901, total_suc: 86, total: 163
current acc: 0.524390243902439, total_suc: 86, total: 164
current acc: 0.52727272727

current acc: 0.46875, total_suc: 135, total: 288
current acc: 0.4671280276816609, total_suc: 135, total: 289
current acc: 0.4689655172413793, total_suc: 136, total: 290
current acc: 0.46735395189003437, total_suc: 136, total: 291
current acc: 0.4691780821917808, total_suc: 137, total: 292
current acc: 0.46757679180887374, total_suc: 137, total: 293
current acc: 0.46598639455782315, total_suc: 137, total: 294
current acc: 0.46440677966101696, total_suc: 137, total: 295
current acc: 0.46283783783783783, total_suc: 137, total: 296
current acc: 0.4612794612794613, total_suc: 137, total: 297
current acc: 0.46308724832214765, total_suc: 138, total: 298
current acc: 0.46488294314381273, total_suc: 139, total: 299
current acc: 0.4666666666666667, total_suc: 140, total: 300
current acc: 0.4684385382059801, total_suc: 141, total: 301
current acc: 0.47019867549668876, total_suc: 142, total: 302
current acc: 0.47194719471947194, total_suc: 143, total: 303
current acc: 0.47368421052631576, total_su

current acc: 0.4834905660377358, total_suc: 205, total: 424
current acc: 0.4823529411764706, total_suc: 205, total: 425
current acc: 0.4835680751173709, total_suc: 206, total: 426
current acc: 0.4847775175644028, total_suc: 207, total: 427
current acc: 0.48364485981308414, total_suc: 207, total: 428
current acc: 0.48484848484848486, total_suc: 208, total: 429
current acc: 0.48604651162790696, total_suc: 209, total: 430
current acc: 0.48491879350348027, total_suc: 209, total: 431
current acc: 0.4861111111111111, total_suc: 210, total: 432
current acc: 0.48498845265588914, total_suc: 210, total: 433
current acc: 0.4838709677419355, total_suc: 210, total: 434
current acc: 0.4850574712643678, total_suc: 211, total: 435
current acc: 0.48394495412844035, total_suc: 211, total: 436
current acc: 0.4851258581235698, total_suc: 212, total: 437
current acc: 0.4840182648401826, total_suc: 212, total: 438
current acc: 0.48519362186788156, total_suc: 213, total: 439
current acc: 0.4863636363636364, 