In [46]:
import os
import re
import pandas as pd
import numpy as np
from math import factorial
from itertools import combinations,product

In [41]:

class Question(object):
    def __init__(self,category):
        self.category=category
        self.equalQuestions=[]
        self.notEqualQuestions=[]
    
    def addQuestion(self,example):
        if(example.label==1):
            self.equalQuestions.append(example.query2)
        else:
            self.notEqualQuestions.append(example.query2)
        
        
class CategoryMap(object):
    def __init__(self):
        self.data = {}
        
    def addExample(self,example):
        question=self.data.get(example.getKey(),Question(example.category))
        question.addQuestion(example) 
        
class InputExample(object):
    def __init__(self,category,query1,query2,label):
        self.re_punctuation='[{}]+'.format(''';'",.!?；‘’“”，。！？''')
        self.category=category
        self.query1=re.sub(self.re_punctuation, '', query1)
        self.query2=re.sub(self.re_punctuation, '', query2 )
        self.label=int(label)
    
    def getKey(self):
        return self.category+'@@'+self.query1
        
    def convert_to_features(self,tokenizer,trans=False):
        encode_data=None
        if trans:
            encode_data=tokenizer.encode_plus(self.query2,self.query1,max_length=64,pad_to_max_length=True)
        else:
            encode_data=tokenizer.encode_plus(self.query1,self.query2,max_length=64,pad_to_max_length=True)
        return InputFeatures(encode_data['input_ids'],encode_data['token_type_ids'],encode_data['attention_mask'],self.label)
    


In [32]:
for e in product(['A','B','C'],['d','e','f']):
    print(e)

('A', 'd')
('A', 'e')
('A', 'f')
('B', 'd')
('B', 'e')
('B', 'f')
('C', 'd')
('C', 'e')
('C', 'f')


In [42]:
class DataProcess(object):
    def __init__(self,data_path,tokenizer=None):
        self.data_path=data_path
        self.tokenizer=tokenizer
        
    def getTrainDataSet(self,file_name=None):
        if file_name is None:
            file_name = 'train.csv'
        examples = self._get_examples(os.path.join(self.data_path,file_name))
        features = self._get_features(examples,is_exchange=False)
        return self._get_dataset(features),len(features)
    
    def getValidDataSet(self,file_name=None):
        if file_name is None:
            file_name = 'dev.csv'
        examples = self._get_examples(os.path.join(self.data_path,file_name))
        features = self._get_features(examples,is_exchange=False)
        return self._get_dataset(features),len(features)
    
    def getTestDataSet(self,file_name=None):
        if file_name is None:
            file_name = 'test.csv'
        examples = self._get_examples(os.path.join(self.data_path,file_name))
        features = self._get_features(examples,is_exchange=False)
        return self._get_dataset(features),len(features)
    
    def savePredictData(self,file_name=None):
        if file_name is None:
            file_name = 'result.csv'
    
    def _get_examples(self,file_name):
        if os.path.exists(file_name):
            data = pd.read_csv(file_name).dropna()
            examples = []
            for i,line in data.iterrows():
                examples.append(InputExample(line['category'],line['query1'],line['query2'],line['label']))
            return examples   
        else:
            raise FileNotFoundError('{0} not found.'.format(data_path))   
    def _get_features(self,examples,is_exchange=True):
        features=[]
        for e in examples:
            features.append(e.convert_to_features(self.tokenizer,False))
            if is_exchange:
                features.append(e.convert_to_features(self.tokenizer,True))
        return features
    
    def _get_dataset(self,features):
        def gen():
            for ex in features:
                yield ({'input_ids': ex.input_ids,'attention_mask': ex.attention_mask,'token_type_ids': ex.token_type_ids},ex.label)
        return tf.data.Dataset.from_generator(gen,
                                              ({'input_ids': tf.int32,
                                                'attention_mask': tf.int32,
                                                'token_type_ids': tf.int32},
                                               tf.int64),
                                              ({'input_ids': tf.TensorShape([None]),
                                                'attention_mask': tf.TensorShape([None]),
                                                'token_type_ids': tf.TensorShape([None])},
                                               tf.TensorShape([])))

In [47]:
dataProcess = DataProcess('data')
examples = dataProcess._get_examples('data/dev.csv')
examples = examples + dataProcess._get_examples('data/train.csv')
questions = {}
for e in examples:
    question = questions.get(e.getKey(),Question(e.category,e.query1))
    question.add(e)
    questions[e.getKey()]=question

In [48]:
len(questions)

1848

In [49]:
examples = []
for value in questions.values():
    examples = examples + value.toExamples()
print(len(examples))    

31716


In [50]:
examples

[<__main__.InputExample at 0x2e59a7e3b48>,
 <__main__.InputExample at 0x2e59a7e3888>,
 <__main__.InputExample at 0x2e59a7d2f88>,
 <__main__.InputExample at 0x2e59a7d23c8>,
 <__main__.InputExample at 0x2e59a7d2708>,
 <__main__.InputExample at 0x2e59a7d2b08>,
 <__main__.InputExample at 0x2e59a7d20c8>,
 <__main__.InputExample at 0x2e59a7d2948>,
 <__main__.InputExample at 0x2e59a7d2348>,
 <__main__.InputExample at 0x2e59a7d2508>,
 <__main__.InputExample at 0x2e59a7d2488>,
 <__main__.InputExample at 0x2e59a7d2908>,
 <__main__.InputExample at 0x2e59a7d2d48>,
 <__main__.InputExample at 0x2e59a7d2a88>,
 <__main__.InputExample at 0x2e59a7d26c8>,
 <__main__.InputExample at 0x2e59a7d2d88>,
 <__main__.InputExample at 0x2e59a7d2248>,
 <__main__.InputExample at 0x2e59a7d2788>,
 <__main__.InputExample at 0x2e59a7d2f48>,
 <__main__.InputExample at 0x2e59a7d2988>,
 <__main__.InputExample at 0x2e59a7d2108>,
 <__main__.InputExample at 0x2e59a7d2288>,
 <__main__.InputExample at 0x2e59a7d28c8>,
 <__main__.