In [1]:
import keras
import numpy as np
import matplotlib.pyplot as plt
 
import random
from keras.datasets import mnist
from keras.models import Model
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K

### 定义loss

In [2]:
# The network concatenates the anchor, positive and negative embeddings
# (x_a, x_p, x_n) along the last axis, so they are split apart again by index
# here. The triplet loss does not need the values in y_true.
def triplet_loss(y_true, y_pred, alpha = 0.4):
    """Per-sample triplet loss on concatenated (anchor, positive, negative) embeddings.

    Args:
        y_true: Unused; kept so the function has the (y_true, y_pred)
            signature of a loss passed to ``model.compile``.
        y_pred: Tensor of shape (batch, 3 * d) whose columns are laid out as
            [anchor | positive | negative], each part being d columns wide.
        alpha: Margin added to (positive distance - negative distance).

    Returns:
        One value per row: ``max(d(a, p) - d(a, n) + alpha, 0)``, where ``d``
        is the squared Euclidean distance.

    Raises:
        ValueError: If the last dimension of ``y_pred`` is not a multiple of 3.
            Unequal slices would otherwise broadcast silently.
    """
    # Indexing the shape (instead of `.as_list()`) works for both TensorShape
    # objects and plain tuples.
    total_length = int(y_pred.shape[-1])
    embedding_dim, remainder = divmod(total_length, 3)
    if remainder:
        raise ValueError(
            "triplet_loss expects the last dimension of y_pred to be a "
            "multiple of 3, got %d" % total_length)

    anchor = y_pred[:, :embedding_dim]
    positive = y_pred[:, embedding_dim:2 * embedding_dim]
    negative = y_pred[:, 2 * embedding_dim:3 * embedding_dim]

    # squared distance between the anchor and the positive
    pos_dist = K.sum(K.square(anchor - positive), axis=1)

    # squared distance between the anchor and the negative
    neg_dist = K.sum(K.square(anchor - negative), axis=1)

    # hinge: only triplets that violate the margin contribute
    basic_loss = pos_dist - neg_dist + alpha
    return K.maximum(basic_loss, 0.0)

### 三元组样本的构造

In [3]:
def create_triple(x_train,y_train):
    """Randomly sample (anchor, positive, negative) triplets from a labelled set.

    As many triplets are drawn as there are samples in ``x_train``. For each
    one, an anchor is picked uniformly at random, a positive is picked from
    all samples sharing the anchor's label (so it may be the anchor itself),
    and a negative is picked from all samples with a different label.
    Sampling uses the ``random`` module, so ``random.seed`` makes it
    reproducible.

    Args:
        x_train: Array of samples, indexed along the first axis.
        y_train: Labels, one per sample; flattened before use, so shape (n,)
            and (n, 1) are both accepted.

    Returns:
        Tuple ``(anchors, positives, negatives)`` of arrays built from the
        selected samples.

    Raises:
        ValueError: If the number of labels differs from the number of
            samples, or if no sample with a different label exists.
    """
    labels = np.ravel(y_train)
    n_samples = x_train.shape[0]
    if labels.shape[0] != n_samples:
        raise ValueError(
            "x_train and y_train must have the same number of samples, "
            "got %d and %d" % (n_samples, labels.shape[0]))

    # The candidate index arrays depend only on the label, so build them once
    # per label instead of rescanning y_train for every triplet.
    # np.flatnonzero always returns a 1-D array, even for a single match.
    same_label_indices = {}
    other_label_indices = {}

    x_anchors = []
    x_positives = []
    x_negatives = []
    for _ in range(n_samples):
        # pick a random anchor sample x
        anchor_idx = random.randint(0, n_samples - 1)
        label = labels[anchor_idx]

        if label not in same_label_indices:
            same_label_indices[label] = np.flatnonzero(labels == label)
            other_label_indices[label] = np.flatnonzero(labels != label)
        indices_for_pos = same_label_indices[label]
        indices_for_neg = other_label_indices[label]
        if len(indices_for_neg) == 0:
            raise ValueError(
                "create_triple needs at least two distinct labels to draw a negative sample")

        # pick a random sample x+ with the same label as x
        positive_idx = indices_for_pos[random.randint(0, len(indices_for_pos) - 1)]
        # pick a random sample x- with a different label than x
        negative_idx = indices_for_neg[random.randint(0, len(indices_for_neg) - 1)]

        x_anchors.append(x_train[anchor_idx])
        x_positives.append(x_train[positive_idx])
        x_negatives.append(x_train[negative_idx])

    return np.array(x_anchors), np.array(x_positives), np.array(x_negatives)

### 特征提取网络

In [None]:
def create_base_network(input_shape):
    '''Build the base network that is shared between branches (feature extraction).

    The network stacks five 3x3 "same"-padded ReLU convolutions with three
    2x2 max-pooling steps in between, followed by Flatten, Dense(40),
    Dropout(0.1) and a final Dense(4) layer with ReLU activation, so each
    input is mapped to a 4-dimensional embedding.

    Args:
        input_shape: Shape of a single input sample without the batch axis,
            e.g. (28, 28, 1).

    Returns:
        A Keras ``Model`` mapping an input of ``input_shape`` to the embedding.
    '''
    base_num=16
    # Use the caller's shape rather than a hard-coded (28, 28, 1).
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(8, (3, 3), padding='same', activation='relu')(inputs)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(8, (3, 3), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(8, (3, 3), padding='same', activation='relu')(x)
    x = layers.Conv2D(base_num, (3, 3), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(base_num*2, (3, 3), padding='same', activation='relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(40, activation='relu')(x)
    x = layers.Dropout(0.1)(x)
    # Final embedding layer.
    x = layers.Dense(4, activation='relu')(x)
    return Model(inputs, x)

### 加载数据

In [None]:
# Seed Python's RNG so that create_triple (which samples with `random`)
# draws the same triplets on every Restart & Run All.
TRIPLET_SEED = 42
random.seed(TRIPLET_SEED)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a trailing channel axis, convert to float32 and divide by 255.
# Written as an expression (not `/=`) so re-running the cell cannot rescale
# already-scaled arrays.
x_train = x_train.reshape(-1,28,28,1).astype('float32') / 255
x_test = x_test.reshape(-1,28,28,1).astype('float32') / 255
input_shape = x_train.shape[1:]

# One (anchor, positive, negative) triplet per sample, for train and test.
train_a_pairs,train_p_pairs,train_n_pairs=create_triple(x_train,y_train)
test_a_pairs,test_p_pairs,test_n_pairs=create_triple(x_test,y_test)

### 构建网络和训练

In [None]:
base_network = create_base_network(input_shape)

input_a = layers.Input(shape=input_shape)
input_p = layers.Input(shape=input_shape)
input_n = layers.Input(shape=input_shape)

# Because the same `base_network` instance is re-used, its weights are
# shared across all three branches (anchor, positive, negative).
processed_a = base_network(input_a)
processed_p = base_network(input_p)
processed_n = base_network(input_n)

# Concatenate the three embeddings along the last axis; triplet_loss splits
# them apart again by index.
merged_vector = layers.concatenate([processed_a, processed_p, processed_n], axis=-1, name='merged_layer')
model = Model([input_a,input_p, input_n], merged_vector)

# Drawing the model needs the optional pydot/graphviz tooling; the diagram is
# only informational, so do not abort training when it is unavailable.
try:
    keras.utils.plot_model(model,"triplet_Model.png",show_shapes=True)
except (ImportError, OSError) as plot_error:
    print("Skipping model plot:", plot_error)
model.summary()

# train
rms = RMSprop()
model.compile(loss=triplet_loss, optimizer=rms)

# triplet_loss ignores y_true, but fit() still needs a target array with one
# row per sample. Use zeros rather than np.empty so no uninitialised memory
# is fed to Keras.
tr_y = np.zeros((train_a_pairs.shape[0],1))
te_y = np.zeros((test_a_pairs.shape[0],1))
history=model.fit([train_a_pairs,train_p_pairs,train_n_pairs], tr_y,
          batch_size=128,
          epochs=20,verbose=2,
          validation_data=([test_a_pairs,test_p_pairs,test_n_pairs], te_y))

### 可视化

In [None]:
import seaborn as sns
import matplotlib.patheffects as PathEffects
from sklearn.manifold import TSNE
# Define our own plot function
def scatter(x, labels, subtitle=None):
    """Scatter-plot 2-D points coloured by digit label, with a label at each class median.

    Args:
        x: Array of points; columns 0 and 1 are used as the plot coordinates.
        labels: Integer class labels in the range 0-9, one per row of ``x``.
        subtitle: Optional figure title. When given, it is also used as the
            file name passed to ``plt.savefig``; when ``None`` the figure is
            neither titled nor saved.

    Returns:
        None. The figure is created (and optionally saved) as a side effect.
    """
    labels = np.asarray(labels)

    # We choose a color palette with seaborn: one colour per digit.
    palette = np.array(sns.color_palette("hls", 10))

    # We create a scatter plot.
    plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # `np.int` was removed from NumPy; the builtin `int` is the replacement.
    ax.scatter(x[:,0], x[:,1], lw=0, s=40,
               c=palette[labels.astype(int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')

    # We add the labels for each digit.
    for i in range(10):
        members = x[labels == i, :]
        if len(members) == 0:
            # No points of this digit: the median would be NaN, so skip it.
            continue
        # Position of each label: the median of the digit's points.
        xtext, ytext = np.median(members, axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        # White outline keeps the digit readable on top of the points.
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])

    if subtitle is not None:
        plt.suptitle(subtitle)
        # Only save when there is a name to save under.
        plt.savefig(subtitle)
# Number of leading train/test images to embed and plot.
n_vis = 512

# Embed the selected images with the shared base network.
X_train_trm = base_network.predict(x_train[:n_vis].reshape(-1, 28, 28, 1))
X_test_trm = base_network.predict(x_test[:n_vis].reshape(-1, 28, 28, 1))

# Project the embeddings with t-SNE (default settings), train first, then test.
tsne = TSNE()
train_tsne_embeds = tsne.fit_transform(X_train_trm)
eval_tsne_embeds = tsne.fit_transform(X_test_trm)

# One labelled scatter plot per split; the title doubles as the output file name.
for embeds, digit_labels, plot_title in (
    (train_tsne_embeds, y_train[:n_vis], "Training Data After TNN"),
    (eval_tsne_embeds, y_test[:n_vis], "Validation Data After TNN"),
):
    scatter(embeds, digit_labels, plot_title)