In [1]:
import tensorflow as tf
import numpy as np
import tensorflow.keras as keras
from tensorflow.keras import layers
import pickle
from time import gmtime, strftime

In [2]:
# Directory that holds every dataset artifact read by this notebook.
base_dir = './data/'

# Load the pickled configuration dictionary. The context manager closes the
# file handle deterministically instead of leaving it to the garbage collector.
# NOTE(review): pickle can execute arbitrary code on load -- only read files
# produced by a trusted preprocessing run.
with open(base_dir + 'data_config.dict', 'rb') as config_file:
    data_config = pickle.load(config_file)

MAX_LEN = data_config['MAX_LEN']
WORD_DIM = data_config['WORD_DIM']
RELATION_NUM = data_config['RELATION_NUM']

# POS_MIN is subtracted from every relative position before it is used as a
# row index into the position-embedding tables, which have POS_EMBED_LEN rows.
POS_MIN = -100
POS_EMBED_LEN = 200

In [3]:
# Read the pickled word-embedding matrix; `with` guarantees the file is closed.
# NOTE(review): pickle can execute arbitrary code on load -- trusted files only.
with open(base_dir + 'word_embed', 'rb') as embed_file:
    word_embed = pickle.load(embed_file)

# Transpose so that axis 0 is the one indexed by word id in the model's
# embedding lookup.
word_embed = np.transpose(word_embed)

# The last row of the (transposed) matrix is used as the padding token id.
PAD_ID = word_embed.shape[0]-1
print(word_embed.shape, PAD_ID)

(22549, 300) 22548


In [4]:
# Mapping from relation name to integer label; `with` closes the file handle
# as soon as the pickle has been read.
# NOTE(review): pickle can execute arbitrary code on load -- trusted files only.
with open(base_dir + 'rel2lb.dict', 'rb') as rel2lb_file:
    rel2lb = pickle.load(rel2lb_file)

# Integer label of the 'Other' relation, which the model treats specially.
OTHER_LABEL = rel2lb['Other']

In [5]:
# load dataset from tfrecord

def pad_fixed_length(words):
    """Right-pad a 1-D tensor of word ids with PAD_ID up to MAX_LEN entries.

    The amount of padding is derived from ``words.shape[0]`` and baked into a
    ``tf.constant``, so the first dimension has to be a concrete value when
    this runs (as it is when invoked through ``tf.py_function``).
    """
    missing = MAX_LEN-words.shape[0]
    paddings = tf.constant([[0, missing]])
    return tf.pad(words, paddings, constant_values=PAD_ID)

def processing(raw):
    """Decode one serialized example into ``(pos1, pos2, label, words)``.

    Args:
        raw: A scalar string tensor holding one serialized ``tf.train.Example``
            with the int64 features ``idxs`` (2 values), ``label`` (1 value)
            and ``words`` (variable length).

    Returns:
        pos1, pos2: int32 tensors of length MAX_LEN. Entry ``t`` is
            ``t - idxs[k] - POS_MIN``, i.e. the offset of position ``t`` from
            the k-th stored index, shifted by ``-POS_MIN``.
            (presumably ``idxs`` are the two entity token positions -- confirm
            against the preprocessing script)
        label: int32 tensor of shape ``[1]``.
        words: int32 tensor of shape ``[MAX_LEN]``; the word ids right-padded
            with PAD_ID.

    The padding is done with in-graph ops rather than ``tf.py_function``, so
    the pipeline needs no Python round-trip per record and ``words`` keeps a
    static shape.
    """
    features = tf.io.parse_single_example(
        raw,
        features={
            'idxs': tf.io.FixedLenFeature([2], tf.int64),
            'label': tf.io.FixedLenFeature([1], tf.int64),
            'words': tf.io.VarLenFeature(tf.int64)
        }
    )
    idxs = tf.cast(features['idxs'], tf.int32)

    # Relative position of every slot with respect to each stored index,
    # shifted by -POS_MIN before being used as an embedding row index.
    token_positions = tf.range(0, MAX_LEN, 1, dtype=tf.int32)
    pos1 = token_positions - idxs[0] - POS_MIN
    pos2 = token_positions - idxs[1] - POS_MIN

    label = tf.cast(features['label'], tf.int32)

    words = tf.cast(tf.sparse.to_dense(features['words']), tf.int32)
    # Dynamic amount of right padding needed to reach MAX_LEN entries.
    pad_amount = tf.cast(MAX_LEN, tf.int32) - tf.shape(words)[0]
    words = tf.pad(words, [[0, pad_amount]], constant_values=PAD_ID)
    # Record the now-known length so downstream batches carry a static shape.
    words = tf.ensure_shape(words, [MAX_LEN])
    return pos1, pos2, label, words

# Training pipeline: decode, shuffle with a 2000-example buffer, batch by 128.
train_ds = (
    tf.data.TFRecordDataset(filenames = [base_dir + 'train.tfrecords'])
    .map(processing)
    .shuffle(2000)
    .batch(128)
)

# Evaluation pipeline: same decoding and batching, but no shuffle. Shuffling
# cannot change evaluation metrics; dropping it avoids filling a shuffle buffer
# on every pass and makes the order of collected predictions deterministic.
test_ds = (
    tf.data.TFRecordDataset(filenames = [base_dir + 'test.tfrecords'])
    .map(processing)
    .batch(128)
)

In [6]:
class Network(keras.Model):
    """Convolutional relation scorer trained with a margin-based ranking loss.

    Each position is represented by its word embedding concatenated with two
    position embeddings. Every position is then joined with its left and right
    neighbours, passed through a width-1 tanh convolution, max-pooled over the
    sequence and scored against a learned class matrix.

    Args:
        word_embed: Initial word-embedding matrix; stored as a trainable
            float32 variable and indexed by word id along axis 0.
        pos_dim: Width of each position embedding.
        word_dim: Stored on the instance as ``self.word_dim``.
        class_num: Number of columns of the class matrix, i.e. number of
            scores produced per example.
        class_dim: Number of convolution filters / rows of the class matrix.
        other_label: Integer label whose examples use ``other_loss_func``
            instead of ``class_loss_func``.
        hyperparams: Dict with the keys ``'m_neg'``, ``'m_pos'``, ``'gamma'``
            and ``'beta'`` (margins, scaling factor and L2 coefficient).

    NOTE(review): labels other than ``other_label`` are used directly as column
    indices into the score matrix, so they are assumed to lie in
    ``[0, class_num)`` -- confirm against the label mapping.
    """

    def __init__(self, word_embed, pos_dim, word_dim, class_num, class_dim, other_label, hyperparams):
        super(Network, self).__init__()
        self.pos_dim = pos_dim
        self.word_dim = word_dim
        self.other_label = tf.constant(other_label, dtype=tf.int32)
        self.m_neg = tf.constant(hyperparams['m_neg'])
        self.m_pos = tf.constant(hyperparams['m_pos'])
        self.gamma = tf.constant(hyperparams['gamma'])
        self.beta = tf.constant(hyperparams['beta'])

        # Trainable lookup tables.
        self.word_embed = tf.Variable(word_embed, dtype=tf.float32)
        self.pos1_embed = tf.Variable(tf.random.uniform([POS_EMBED_LEN, pos_dim],minval=0,maxval=1), dtype=tf.float32)
        self.pos2_embed = tf.Variable(tf.random.uniform([POS_EMBED_LEN, pos_dim],minval=0,maxval=1), dtype=tf.float32)

        # Class matrix initialised uniformly in +/- sqrt(6 / (class_num + class_dim)).
        class_embed_init_param = tf.sqrt(6/(class_num + class_dim))
        self.class_matrix = tf.Variable(tf.random.uniform([class_dim, class_num],minval=-class_embed_init_param, maxval=class_embed_init_param), dtype=tf.float32)

        # Shift matrices. Left-multiplying a [MAX_LEN, d] block by ctx_mat1
        # moves every row down by one (row t becomes row t-1, first row zero);
        # ctx_mat2 moves every row up by one (row t becomes row t+1, last row
        # zero). np.eye with k=-1 / k=+1 puts ones on exactly those diagonals.
        self.ctx_mat1 = tf.constant(np.eye(MAX_LEN, k=-1), dtype=tf.float32)
        self.ctx_mat2 = tf.constant(np.eye(MAX_LEN, k=1), dtype=tf.float32)

        # Width-1 convolution over the [previous, current, next] features,
        # followed by a max pool whose window spans MAX_LEN positions.
        self.conv1 = layers.Conv1D(class_dim, 1, padding='same', activation=tf.nn.tanh)
        self.pool1 = layers.MaxPool1D(MAX_LEN, padding='same')

    @staticmethod
    @tf.function(experimental_relax_shapes=True)
    def other_loss_func(score, m_neg, gamma):
        """Loss for examples labelled ``other_label``.

        Computes ``log(1 + exp(gamma * (m_neg + max_c score[:, c])))`` per row.
        ``tf.nn.softplus`` evaluates the same expression without overflowing
        to ``inf`` when the exponent is large.
        """
        return tf.nn.softplus(gamma*(m_neg + tf.reduce_max(score, axis=1)))

    @staticmethod
    @tf.function
    def remove_ele(data):
        """Drop one entry from a score row.

        ``data`` is a 1-D float tensor whose last element is a label; the
        remaining elements form the row. Returns the row with the element at
        index ``label`` removed.
        """
        label = tf.cast(data[-1], dtype=tf.int32)
        row = data[:-1]
        return tf.concat([row[:label],row[label+1:]], axis=0)

    @staticmethod
    @tf.function(experimental_relax_shapes=True)
    def class_loss_func(score, class_label, m_neg, m_pos, gamma):
        """Loss for examples whose label is not ``other_label``.

        Args:
            score: Float tensor with one row of class scores per example.
            class_label: Integer tensor with one label per row, shaped so
                that it can be concatenated to ``score`` along axis 1.
            m_neg, m_pos, gamma: Margins and scaling factor.

        Returns:
            Per-example sum of
            ``log(1 + exp(gamma * (m_pos - score_of_label)))`` and
            ``log(1 + exp(gamma * (m_neg + best_other_score)))``, where
            ``best_other_score`` is the maximum over the remaining columns.
            Both terms use ``tf.nn.softplus`` so that large exponents do not
            overflow.
        """
        # Score assigned to each example's own label.
        true_score = tf.gather_nd(score, tf.expand_dims(class_label, axis=1), batch_dims=1)
        first_term = tf.squeeze(tf.nn.softplus(gamma*(m_pos - true_score)), axis=1)

        # Append the label to each row so remove_ele can strip the matching
        # column, leaving only the competing scores.
        score_temp = tf.map_fn(Network.remove_ele, tf.concat([score, tf.cast(class_label, dtype=tf.float32)], axis=1), dtype=tf.float32)
        second_term = tf.nn.softplus(gamma*(m_neg + tf.reduce_max(score_temp, axis=1)))
        return first_term + second_term

    def call(self, inputs, training):
        """Score a batch and, when training, also compute the loss.

        Args:
            inputs: Sequence ``[pos1, pos2, label, words]``. ``label`` is only
                read when ``training`` is truthy.
            training: When truthy, the loss is computed and returned as well.

        Returns:
            ``scores`` with one row per example, or ``(scores, loss_avg)``
            when training. ``loss_avg`` is the summed per-example losses plus
            the L2 penalty, divided by the batch size.
        """
        pos1 = inputs[0]
        pos2 = inputs[1]
        label = inputs[2]
        words = inputs[3]

        # Per-position features: two position embeddings plus the word embedding.
        pf1 = tf.nn.embedding_lookup(self.pos1_embed, pos1)
        pf2 = tf.nn.embedding_lookup(self.pos2_embed, pos2)
        wf = tf.nn.embedding_lookup(self.word_embed, words)
        wf = tf.concat([pf1, pf2, wf], axis=2)

        # Attach the left and right neighbour of every position.
        wf_before = tf.matmul(self.ctx_mat1, wf)
        wf_follow = tf.matmul(self.ctx_mat2, wf)
        wf_context = tf.concat([wf_before, wf, wf_follow], axis=2)

        x = self.conv1(wf_context)
        x = self.pool1(x)

        # Project the pooled vector onto the class matrix; drop the length-1 axis.
        x = tf.matmul(x, self.class_matrix)
        x = tf.squeeze(x, axis=1)

        if training:
            # Split the batch by whether the example carries other_label.
            other_mask = tf.squeeze(tf.equal(label, self.other_label), axis=1)
            class_mask = tf.math.logical_not(other_mask)
            other_scores = tf.boolean_mask(x, other_mask)
            class_scores = tf.boolean_mask(x, class_mask)
            class_label = tf.boolean_mask(label, class_mask)

            loss1 = Network.other_loss_func(other_scores, self.m_neg, self.gamma)
            loss2 = Network.class_loss_func(class_scores, class_label, self.m_neg, self.m_pos, self.gamma)
            loss_sum = tf.reduce_sum(tf.concat([loss1,loss2], axis=0))

            # L2 penalty over the embedding tables, class matrix and conv kernel.
            l2_penalty = tf.keras.regularizers.l2(self.beta)
            theta_loss = l2_penalty(self.word_embed) + l2_penalty(self.pos1_embed) + l2_penalty(self.pos2_embed)\
                + l2_penalty(self.class_matrix) + l2_penalty(self.conv1.weights[0])
            loss_sum = theta_loss + loss_sum
            loss_avg = loss_sum / tf.cast(tf.shape(label)[0], tf.float32)
            return x, loss_avg
        else:
            return x
        

In [7]:
# Width of each position embedding and number of convolution filters.
pos_dim = 70
class_dim = 1000

# Margins, scaling factor and L2 coefficient read by Network.
hyperparams = dict(m_neg=0.5, m_pos=3., gamma=2.0, beta=0.001)

# The model produces RELATION_NUM-1 scores per example.
network = Network(
    word_embed=word_embed,
    pos_dim=pos_dim,
    word_dim=WORD_DIM,
    class_num=RELATION_NUM-1,
    class_dim=class_dim,
    other_label=OTHER_LABEL,
    hyperparams=hyperparams,
)

In [8]:
# Adam optimizer shared by every training step.
optimizer = keras.optimizers.Adam(learning_rate=0.001)

In [9]:
        
def train_func(dataset):
    """Run one optimisation pass over ``dataset`` and return the mean batch loss.

    Args:
        dataset: Iterable yielding ``(pos1, pos2, label, words)`` batches.

    Returns:
        The average of the per-batch losses returned by ``network``.

    Raises:
        ValueError: If ``dataset`` yields no batches (previously this surfaced
            as a bare ZeroDivisionError).
    """
    total_loss = 0
    batch_count = 0
    for pos1, pos2, label, words in dataset:
        with tf.GradientTape() as tape:
            # The scores are not needed for the update, only the loss.
            _, loss = network([pos1, pos2, label, words], training=True)
        variables = network.trainable_variables
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))
        total_loss += loss
        batch_count += 1
    if batch_count == 0:
        raise ValueError('train_func received a dataset that yielded no batches')
    return total_loss / batch_count

def test_func(dataset):
    """Evaluate ``network`` on ``dataset``.

    An example is predicted as OTHER_LABEL when all of its scores are
    negative, otherwise as the index of its highest score.

    Args:
        dataset: Iterable yielding ``(pos1, pos2, label, words)`` batches.

    Returns:
        accuracy: Fraction of examples whose prediction equals the label.
        pred_all: List with one integer prediction per example.
        label_all: List with one entry per example, each being that example's
            row of ``label.numpy()`` (not unwrapped to a scalar).

    Raises:
        ValueError: If ``dataset`` yields no examples (previously this
            surfaced as a bare ZeroDivisionError).
    """
    pred_all = []
    label_all = []
    for pos1, pos2, label, words in dataset:
        scores = network([pos1, pos2, label, words], training=False).numpy()
        # Decide the whole batch at once instead of issuing one eager
        # TensorFlow op per example.
        best_class = scores.argmax(axis=1)
        pred = np.where(scores.max(axis=1) < 0, OTHER_LABEL, best_class)
        pred_all.extend(pred.tolist())
        label_all.extend(list(label.numpy()))

    if not label_all:
        raise ValueError('test_func received a dataset that yielded no examples')

    correct = sum(1 for p, t in zip(pred_all, label_all) if p == t)
    accuracy = correct / len(label_all)
    return accuracy, pred_all, label_all

In [10]:
# Train for 10 epochs. After each epoch, measure accuracy on both splits and
# keep a checkpoint of the weights with the highest test accuracy seen so far.
best_acc = 0
save_dir = './saved/'
for epoch in range(10):
    loss_avg = train_func(train_ds)
    train_accuracy, pred_all, label_all = test_func(train_ds)
    test_accuracy, pred_all, label_all = test_func(test_ds)
    if best_acc < test_accuracy:
        network.save_weights(save_dir + 'bestckpt')
        best_acc = test_accuracy
    print(
        'training loss is {0: .4f}, training accuracy is {1: .4f}, test accuracy is {2: .4f}'.format(
            loss_avg,
            train_accuracy,
            test_accuracy,
        )
    )

Instructions for updating:
Use tf.identity instead.
training loss is  16.7153, training accuracy is  0.6823, test accuracy is  0.6194
training loss is  12.5337, training accuracy is  0.7883, test accuracy is  0.7052
training loss is  10.2890, training accuracy is  0.8946, test accuracy is  0.7413
training loss is  8.6682, training accuracy is  0.9477, test accuracy is  0.7563
training loss is  7.4365, training accuracy is  0.9808, test accuracy is  0.7622
training loss is  6.5613, training accuracy is  0.9940, test accuracy is  0.7641
training loss is  5.9435, training accuracy is  0.9989, test accuracy is  0.7656
training loss is  5.4464, training accuracy is  0.9998, test accuracy is  0.7622
training loss is  5.0255, training accuracy is  1.0000, test accuracy is  0.7652
training loss is  4.6547, training accuracy is  1.0000, test accuracy is  0.7597


In [11]:
# Restore the best checkpoint saved by the training loop and re-evaluate it.
network.load_weights(save_dir + 'bestckpt')
accuracy, pred_all, label_all = test_func(test_ds)

# Each label entry is indexable; keep only its first element.
label_all = [entry[0] for entry in label_all]
print(accuracy)

0.7655502392344498


In [12]:
import pickle
# Load the label -> relation-name mapping and the list of relation names.
# `with` closes each file as soon as it has been read.
# NOTE(review): pickle can execute arbitrary code on load -- trusted files only.
with open(base_dir + 'lb2rel.dict', 'rb') as lb2rel_file:
    id2rel = pickle.load(lb2rel_file)
with open(base_dir + 'unique_relations', 'rb') as relations_file:
    unique_relations = pickle.load(relations_file)

# Translate integer labels and predictions into relation names.
# NOTE: this overwrites label_all / pred_all in place, so the cell can only be
# run once per evaluation; re-run the evaluation cell first to repeat it.
label_all = [id2rel[item] for item in label_all]
pred_all = [id2rel[item] for item in pred_all]

In [13]:
from sklearn.metrics import classification_report

# Report precision / recall / F1 for every relation except 'Other'.
print(
    classification_report(
        label_all,
        pred_all,
        labels=[relation for relation in unique_relations if relation != 'Other'],
    )
)

                    precision    recall  f1-score   support

Entity-Destination       0.85      0.90      0.88       292
     Entity-Origin       0.83      0.78      0.81       258
 Content-Container       0.81      0.84      0.82       192
     Message-Topic       0.76      0.94      0.84       261
  Product-Producer       0.73      0.74      0.74       231
 Member-Collection       0.78      0.92      0.84       233
      Cause-Effect       0.91      0.90      0.90       328
 Instrument-Agency       0.71      0.73      0.72       156
   Component-Whole       0.86      0.75      0.80       312

         micro avg       0.81      0.84      0.83      2263
         macro avg       0.80      0.83      0.82      2263
      weighted avg       0.82      0.84      0.83      2263

