In [1]:
import re
import os
import tensorflow as tf
import numpy as np
import pandas as pd
from transformers import XLNetTokenizer, TFXLNetForSequenceClassification

In [3]:
model = TFXLNetForSequenceClassification.from_pretrained('model/xlnet')

In [4]:
tokenizer = XLNetTokenizer.from_pretrained('model/xlnet')

In [5]:
class InputFeatures(object):
    def __init__(self,input_ids,token_type_ids,attention_mask,label):
        self.input_ids=input_ids
        self.token_type_ids=token_type_ids
        self.attention_mask=attention_mask 
        self.label=int(label)
        
class InputExample(object):
    def __init__(self,category,query1,query2,label):
        self.re_punctuation='[{}]+'.format(''';'",.!?；‘’“”，。！？''')
        self.category=category
        self.query1=re.sub(self.re_punctuation, '', query1)
        self.query2=re.sub(self.re_punctuation, '', query2 )
        self.label=int(label)
        
    def convert_to_features(self,trans=False):
        encode_data=None
        if trans:
            encode_data=tokenizer.encode_plus(self.query2,self.query1,max_length=64,pad_to_max_length=True)
        else:
            encode_data=tokenizer.encode_plus(self.query1,self.query2,max_length=64,pad_to_max_length=True)
        return InputFeatures(encode_data['input_ids'],encode_data['token_type_ids'],encode_data['attention_mask'],self.label)

        
def read_file(data_path):
    if os.path.exists(data_path):
        return pd.read_csv(data_path).dropna()
    else:
        raise FileNotFoundError('{0} not found.'.format(data_path))

def get_examples(data_path):
    examples = []
    for i,line in read_file(data_path).iterrows():
        examples.append(InputExample(line['category'],line['query1'],line['query2'],line['label']))
    return examples

def get_features(examples):
    features=[]
    for e in examples:
        features.append(e.convert_to_features(False))
        features.append(e.convert_to_features(True))
    return features

def get_dataset(features):
    def gen():
        for ex in features:
            yield ({'input_ids': ex.input_ids,'attention_mask': ex.attention_mask,'token_type_ids': ex.token_type_ids},ex.label)
    return tf.data.Dataset.from_generator(gen,
                                          ({'input_ids': tf.int32,
                                            'attention_mask': tf.int32,
                                            'token_type_ids': tf.int32},
                                           tf.int32),
                                          ({'input_ids': tf.TensorShape([None]),
                                            'attention_mask': tf.TensorShape([None]),
                                            'token_type_ids': tf.TensorShape([None])},
                                           tf.TensorShape([])))

In [6]:
train_data = get_examples('data/train.csv')
dev_data = get_examples('data/dev.csv')

In [7]:
train_features = get_features(train_data)
dev_features = get_features(dev_data)

In [23]:
train_dataset = get_dataset(train_features)
dev_dataset = get_dataset(dev_features)

In [24]:
train_dataset = train_dataset.shuffle(128).batch(64).repeat(-1)
dev_dataset = dev_dataset.shuffle(128).batch(64).repeat(-1)

In [25]:
train_dataset

<RepeatDataset shapes: ({input_ids: (None, None), attention_mask: (None, None), token_type_ids: (None, None)}, (None,)), types: ({input_ids: tf.int32, attention_mask: tf.int32, token_type_ids: tf.int32}, tf.int32)>

In [18]:
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])

In [20]:
model.transformer.trainable=False
train_steps = len(train_features)//64
valid_steps = len(dev_features)//64

In [26]:
model.fit(train_dataset,
          epochs=2,
          steps_per_epoch=train_steps,
          validation_data=dev_dataset,
          validation_steps=valid_steps,
          verbose=2)

Train for 273 steps, validate for 62 steps
Epoch 1/2
273/273 - 104s - loss: 3.0236 - accuracy: 0.5975 - val_loss: 0.5208 - val_accuracy: 0.7767
Epoch 2/2
273/273 - 105s - loss: 0.5970 - accuracy: 0.7282 - val_loss: 0.4780 - val_accuracy: 0.7870


<tensorflow.python.keras.callbacks.History at 0x1928c094088>

In [None]:
predict_data = get_examples('data/train.csv')
predict_features = get_features(predict_data)
predict_dataset = get_dataset(predict_features)

In [None]:
model.predict()