In [1]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer
from transformers import BertForNextSentencePrediction, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
from keras_preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import argparse
import random
import re

import warnings
warnings.filterwarnings('ignore')

### 定义loss函数

In [2]:
def MarginRankingLoss(p_scores, n_scores):
    margin = 1
    scores = margin - p_scores + n_scores
    scores = scores.clamp(min=0)

    return scores.mean()

### 加载分词器

In [3]:
device = 0
# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

Loading BERT tokenizer...


### 加载数据

In [4]:
sample_num_memory = []
examples = []

""" 加载所有句子的样本数 """
for line in open('data/training_data/dailydial_sample_num.txt'):
    line = line.strip()
    sample_num_memory.append(int(line))
    
""" 加载所有句子的所有样本 """
for line in open('data/training_data/dailydial_pairs.txt', encoding='utf-8'):
    line = line.strip().split('\t\t')
    sent1 = line[0]
    sent2 = line[1]
    examples.append((sent1, sent2))

In [5]:
sample_num_memory, len(sample_num_memory)

([2,
  3,
  3,
  3,
  2,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  2,
  3,
  2,
  3,
  3,
  3,
  2,
  3,
  3,
  2,
  3,
  3,
  3,
  3,
  2,
  2,
  3,
  3,
  3,
  2,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  2,
  2,
  2,
  3,
  3,
  3,
  3,
  2,
  3,
  3,
  3,
  3,
  3,
  2,
  3,
  2,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  3,
  3,
  2,
  3,
  3,
  3,
  2,
  2,
  3,
  3,
  3,
  3,
  3,
  2,
  3,
  2,
  3,
  3,
  3,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  3,
  3,
  2,
  3,
  2,
  3,
  3,
  3,
  2,
  3,
  3,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  3,
  3,
  2,
  2,
  3,
  3,
  3,
  2,
  3,
  3,
  2,
  2,
  2,
  2,
  2,
  3,
  3,
  3,
  2,
  3,
  3,
  2,
  3,
  2,
  2,
  3,
  2,
  2,
  2,
  3,
  2,
  3,
  2,
  2,
  2,
  2,
  2,
  3,
  2,
  3,
  3,
  3,
  3,
  2,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  2,
  2,
  3,
  2,
  3,
  2,
  3,
  3,
  2,
  2,
  2,
  3,


In [6]:
examples, len(examples)

([('The kitchen stinks.', "I'll throw out the garbage."),
  ('The kitchen stinks.',
   "I'd hate for you to think I was lazy in returning your money."),
  ('So Dick, how about getting some coffee for tonight?',
   'Coffee? I don ’ t honestly like that kind of stuff.'),
  ('So Dick, how about getting some coffee for tonight?',
   'What ’ s wrong with that? Cigarette is the thing I go crazy for.'),
  ('So Dick, how about getting some coffee for tonight?',
   'I hated the hand-me-downs to wear when I was a kid.'),
  ('Are things still going badly with your houseguest?',
   'Getting worse. Now he ’ s eating me out of house and home. I ’ Ve tried talking to him but it all goes in one ear and out the other. He makes himself at home, which is fine. But what really gets me is that yesterday he walked into the living room in the raw and I had company over! That was the last straw.'),
  ('Are things still going badly with your houseguest?',
   'You ’ re right. Everything is probably going to com

### 按句子打包样本

In [7]:
grouped_examples = []
count = 0

for i in sample_num_memory:
    grouped_examples.append(examples[count:count+i])
    count += i
print('The group number is: ' + str(len(grouped_examples)))

The group number is: 30827


In [8]:
grouped_examples

[[('The kitchen stinks.', "I'll throw out the garbage."),
  ('The kitchen stinks.',
   "I'd hate for you to think I was lazy in returning your money.")],
 [('So Dick, how about getting some coffee for tonight?',
   'Coffee? I don ’ t honestly like that kind of stuff.'),
  ('So Dick, how about getting some coffee for tonight?',
   'What ’ s wrong with that? Cigarette is the thing I go crazy for.'),
  ('So Dick, how about getting some coffee for tonight?',
   'I hated the hand-me-downs to wear when I was a kid.')],
 [('Are things still going badly with your houseguest?',
   'Getting worse. Now he ’ s eating me out of house and home. I ’ Ve tried talking to him but it all goes in one ear and out the other. He makes himself at home, which is fine. But what really gets me is that yesterday he walked into the living room in the raw and I had company over! That was the last straw.'),
  ('Are things still going badly with your houseguest?',
   'You ’ re right. Everything is probably going to c

### 整合所有句子的正/负样本对

In [9]:
print('start generating pos and neg pairs ... ')
pos_neg_pairs = []

for i in range(len(grouped_examples)):
    if len(grouped_examples[i]) == 2:
        pos_neg_pairs.append(grouped_examples[i])
    else:
        pos_neg_pairs.append([grouped_examples[i][0], grouped_examples[i][1]])
        pos_neg_pairs.append([grouped_examples[i][0], grouped_examples[i][2]])
        pos_neg_pairs.append([grouped_examples[i][1], grouped_examples[i][2]])
        
print('there are ' + str(len(pos_neg_pairs)) + ' samples been generated...')

start generating pos and neg pairs ... 
there are 86507 samples been generated...


In [10]:
pos_neg_pairs

[[('The kitchen stinks.', "I'll throw out the garbage."),
  ('The kitchen stinks.',
   "I'd hate for you to think I was lazy in returning your money.")],
 [('So Dick, how about getting some coffee for tonight?',
   'Coffee? I don ’ t honestly like that kind of stuff.'),
  ('So Dick, how about getting some coffee for tonight?',
   'What ’ s wrong with that? Cigarette is the thing I go crazy for.')],
 [('So Dick, how about getting some coffee for tonight?',
   'Coffee? I don ’ t honestly like that kind of stuff.'),
  ('So Dick, how about getting some coffee for tonight?',
   'I hated the hand-me-downs to wear when I was a kid.')],
 [('So Dick, how about getting some coffee for tonight?',
   'What ’ s wrong with that? Cigarette is the thing I go crazy for.'),
  ('So Dick, how about getting some coffee for tonight?',
   'I hated the hand-me-downs to wear when I was a kid.')],
 [('Are things still going badly with your houseguest?',
   'Getting worse. Now he ’ s eating me out of house and h

### 分词

In [17]:
print('start tokenizing pos and neg pairs ... ')
pos_neg_inputs = []

for i in range(len(pos_neg_pairs)):
    if i == 16:
        break
        
    example1 = [pos_neg_pairs[i][0]]
    example2 = [pos_neg_pairs[i][1]]
    encoded_example1 = tokenizer.batch_encode_plus(batch_text_or_text_pairs=example1, add_special_tokens=True, max_length=256, padding='max_length', truncation=True, return_tensors='pt')
    encoded_example2 = tokenizer.batch_encode_plus(batch_text_or_text_pairs=example2, add_special_tokens=True, max_length=256, padding='max_length', truncation=True, return_tensors='pt')
    pos_neg_inputs.append([encoded_example1, encoded_example2])
    
print('there are ' + str(len(pos_neg_inputs)) + ' samples been tokenized...')

start tokenizing pos and neg pairs ... 
there are 16 samples been tokenized...


In [18]:
pos_neg_inputs

[[{'input_ids': tensor([[  101,  1996,  3829, 27136,  2015,  1012,   102,  1045,  1005,  2222,
            5466,  2041,  1996, 13044,  1012,   102,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              

In [19]:
pos_neg_inputs[0]

[{'input_ids': tensor([[  101,  1996,  3829, 27136,  2015,  1012,   102,  1045,  1005,  2222,
           5466,  2041,  1996, 13044,  1012,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,   