In [1]:
import re
import os
import tensorflow as tf
import numpy as np
import pandas as pd
from transformers import XLNetTokenizer, TFXLNetModel,TFXLNetForSequenceClassification

In [2]:
model = TFXLNetForSequenceClassification.from_pretrained('../model/xlnet/tf_zh')
model.logits_proj.activation=tf.keras.activations.sigmoid
model.transformer.trainable=False

In [3]:
model.summary()

Model: "tfxl_net_for_sequence_classification"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
transformer (TFXLNetMainLaye multiple                  116718336 
_________________________________________________________________
sequence_summary (TFSequence multiple                  590592    
_________________________________________________________________
logits_proj (Dense)          multiple                  769       
Total params: 117,309,697
Trainable params: 591,361
Non-trainable params: 116,718,336
_________________________________________________________________


In [4]:
tokenizer = XLNetTokenizer.from_pretrained('../model/xlnet/tf_zh')

In [5]:
class InputFeatures(object):
    def __init__(self,input_ids,token_type_ids,attention_mask,label):
        self.input_ids=input_ids
        self.token_type_ids=token_type_ids
        self.attention_mask=attention_mask 
        self.label=int(label)
        
class InputExample(object):
    def __init__(self,category,query1,query2,label):
        self.re_punctuation='[{}]+'.format(''';'",.!?；‘’“”，。！？''')
        self.category=category
        self.query1=re.sub(self.re_punctuation, '', query1)
        self.query2=re.sub(self.re_punctuation, '', query2 )
        self.label=int(label)
        
    def convert_to_features(self,trans=False):
        encode_data=None
        if trans:
            encode_data=tokenizer.encode_plus(self.query2,self.query1,max_length=64,pad_to_max_length=True)
        else:
            encode_data=tokenizer.encode_plus(self.query1,self.query2,max_length=64,pad_to_max_length=True)
#         return model(inputs=encode_data['input_ids'],
#                      attention_mask=encode_data['attention_mask'],
#                      token_type_ids=encode_data['token_type_ids'])[0],self.label
        return InputFeatures(encode_data['input_ids'],encode_data['token_type_ids'],encode_data['attention_mask'],self.label)

        
def read_file(data_path):
    if os.path.exists(data_path):
        return pd.read_csv(data_path).dropna()
    else:
        raise FileNotFoundError('{0} not found.'.format(data_path))

def get_examples(data_path):
    examples = []
    for i,line in read_file(data_path).iterrows():
        examples.append(InputExample(line['category'],line['query1'],line['query2'],line['label']))
    return examples

def get_features(examples):
    features=[]
    for e in examples:
        features.append(e.convert_to_features(False))
        features.append(e.convert_to_features(True))
    return features

def get_dataset(features):
    def gen():
        for ex in features:
            yield ({'input_ids': ex.input_ids,'attention_mask': ex.attention_mask,'token_type_ids': ex.token_type_ids},ex.label)
    return tf.data.Dataset.from_generator(gen,
                                          ({'input_ids': tf.int32,
                                            'attention_mask': tf.int32,
                                            'token_type_ids': tf.int32},
                                           tf.int64),
                                          ({'input_ids': tf.TensorShape([None]),
                                            'attention_mask': tf.TensorShape([None]),
                                            'token_type_ids': tf.TensorShape([None])},
                                           tf.TensorShape([])))

In [6]:
train_data = get_examples('data/train.csv')
dev_data = get_examples('data/dev.csv')

In [7]:
train_features = get_features(train_data)
dev_features = get_features(dev_data)

In [8]:
train_steps = len(train_features) // 64
valid_steps = len(dev_features) // 64

In [9]:
train_steps

273

In [10]:
valid_steps

62

In [11]:
train_dataset = get_dataset(train_features)
dev_dataset = get_dataset(dev_features)

In [12]:
train_dataset = train_dataset.shuffle(128).batch(64).repeat(-1)
dev_dataset = dev_dataset.shuffle(128).batch(64).repeat(-1)

In [13]:
train_dataset

<RepeatDataset shapes: ({input_ids: (None, None), attention_mask: (None, None), token_type_ids: (None, None)}, (None,)), types: ({input_ids: tf.int32, attention_mask: tf.int32, token_type_ids: tf.int32}, tf.int64)>

In [14]:
dev_dataset

<RepeatDataset shapes: ({input_ids: (None, None), attention_mask: (None, None), token_type_ids: (None, None)}, (None,)), types: ({input_ids: tf.int32, attention_mask: tf.int32, token_type_ids: tf.int32}, tf.int64)>

In [15]:
model(tokenizer.encode_plus('阿斯蒂芬','大师傅',max_length=64,pad_to_max_length=True,return_tensors='tf'))

(<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.50964546]], dtype=float32)>,)

In [16]:
tokenizer.decode(tokenizer.encode('阿斯蒂芬','大师傅'))

'阿斯蒂芬<sep> 大师傅<sep><cls>'

In [17]:
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])

In [18]:
model.transformer.trainable=False

In [20]:
model.fit(train_dataset,
          epochs=6,
          steps_per_epoch=train_steps,
          validation_data=dev_dataset,
          validation_steps=valid_steps,
          verbose=2)

Train for 273 steps, validate for 62 steps
Epoch 1/6
273/273 - 62s - loss: 0.4891 - accuracy: 0.7660 - val_loss: 0.4188 - val_accuracy: 0.8143
Epoch 2/6
273/273 - 62s - loss: 0.4946 - accuracy: 0.7599 - val_loss: 0.4204 - val_accuracy: 0.8170
Epoch 3/6
273/273 - 62s - loss: 0.4862 - accuracy: 0.7684 - val_loss: 0.4134 - val_accuracy: 0.8148
Epoch 4/6
273/273 - 62s - loss: 0.4857 - accuracy: 0.7649 - val_loss: 0.4195 - val_accuracy: 0.8057
Epoch 5/6
273/273 - 62s - loss: 0.4717 - accuracy: 0.7765 - val_loss: 0.4135 - val_accuracy: 0.8032
Epoch 6/6
273/273 - 62s - loss: 0.4869 - accuracy: 0.7721 - val_loss: 0.4052 - val_accuracy: 0.8178


<tensorflow.python.keras.callbacks.History at 0x24216bb4488>