# Model recovery attack in split learning with multiple data owners

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
def make_dataset(X, Y, f):
    """Build a shuffled ``tf.data.Dataset`` of ``(f(x), y)`` pairs.

    Args:
        X: Feature array; sliced along its first axis into elements.
        Y: Label array; sliced along its first axis into elements.
        f: Function mapped over every feature element.

    Returns:
        The zipped (features, labels) dataset, shuffled with a buffer of
        10000 elements.
    """
    features = tf.data.Dataset.from_tensor_slices(X).map(f)
    labels = tf.data.Dataset.from_tensor_slices(Y)
    return tf.data.Dataset.zip((features, labels)).shuffle(10000)

# Load the credit table; sample(frac=1) returns all rows in random order.
df = pd.read_csv('../datasets/german-credit.csv', index_col=0).sample(frac=1)

# Every column except the "kredit" label is treated as a feature.
features = df.drop(columns=["kredit"])

# Per-column min/max, computed directly instead of running the full describe()
# summary twice. numeric_only mirrors describe()'s numeric-only default.
min_values = features.min(numeric_only=True).to_numpy()
max_values = features.max(numeric_only=True).to_numpy()

# Min-max scale each feature to [0, 1]. A constant column has zero range and
# would divide by zero (NaN); dividing by 1 instead maps it to 0.
value_range = max_values - min_values
value_range = np.where(value_range == 0, 1, value_range)
x = (features.to_numpy() - min_values) / value_range

# Labels as a float32 column vector with one row per sample.
y = df["kredit"].to_numpy().reshape((len(x), 1)).astype("float32")
# x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
# attack_ds = make_dataset(x_train, y_train, lambda t: t)

# The whole table is used for training; the identity lambda leaves features as-is.
train_ds = make_dataset(x, y, lambda t: t)

train_size = len(x)

In [3]:
# Number of label classes. The training loop selects binary cross-entropy and
# binary accuracy when this is 2, and the sparse categorical variants otherwise.
num_classes = 2

def make_f(input_shape):
    """Build the first half of the split model.

    The network is a batch-normalisation layer followed by four ReLU dense
    layers of width 32, 64, 128 and 256.

    Args:
        input_shape: Shape of one input sample (without the batch axis).

    Returns:
        A ``tf.keras.Model`` mapping an input sample to a 256-wide activation.
    """
    inputs = tf.keras.layers.Input(input_shape)
    hidden = tf.keras.layers.BatchNormalization()(inputs)
    for width in (32, 64, 128, 256):
        hidden = tf.keras.layers.Dense(width, activation="relu")(hidden)
    return tf.keras.Model(inputs, hidden)

def make_g(input_shape, class_num):
    """Build the second half of the split model (the classification head).

    Args:
        input_shape: Shape of one intermediate activation (without the batch
            axis).
        class_num: Number of label classes. For 2 classes the head is a single
            sigmoid unit (for binary cross-entropy); for any other value it is
            a ``class_num``-wide softmax (for sparse categorical
            cross-entropy).

    Returns:
        A ``tf.keras.Model`` mapping an intermediate activation to class
        probabilities.
    """
    xin = tf.keras.layers.Input(input_shape)
    x = tf.keras.layers.Dense(512, activation="relu")(xin)
    # x = tf.keras.layers.Dropout(0.5)(x)
    if class_num == 2:
        # Binary case: one probability for the positive class.
        output = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    else:
        # Multi-class case: one probability per class, as the sparse
        # categorical loss and accuracy expect.
        output = tf.keras.layers.Dense(class_num, activation="softmax")(x)
    return tf.keras.Model(xin, output)

# f consumes one feature vector from the training dataset.
input_shape = train_ds.element_spec[0].shape
f = make_f(input_shape)

# g consumes whatever f emits; [1:] strips the batch axis from f's output shape.
intermediate_shape = f.output_shape[1:]
g = make_g(input_shape=intermediate_shape, class_num=num_classes)

In [4]:
# Training hyperparameters for the split model.
batch_size = 128
epoches = 5
# note that iterations is the number of batches we iterate
# (evaluated as (epoches * train_size) // batch_size, i.e. the total number of
# samples to visit expressed in whole batches)
iterations = epoches * train_size // batch_size
# Step size handed to the optimizer in the training cell.
learning_rate = 0.001

In [5]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# Batch the shuffled data, cycle through it indefinitely and stop after
# `iterations` batches.
train_batches = train_ds.batch(batch_size=batch_size, drop_remainder=True).repeat(-1).take(iterations)
# Per-iteration record of the input batch and of the intermediate activation
# f produced for it; the attack cell indexes into these lists.
train_ref = []
z_ref = []

# One [mean loss, mean accuracy] entry per iteration.
log = []
iter_count = 0
log_frequency = 100

for (x_batch, y_batch) in train_batches:

    # Only one gradient call is made per step, so a non-persistent tape is
    # enough and its resources are released as soon as gradient() returns.
    with tf.GradientTape() as tape:
        z = f(x_batch, training=True)
        y_pred = g(z, training=True)
        if num_classes == 2:
            loss = tf.keras.losses.binary_crossentropy(y_true=y_batch, y_pred=y_pred)
        else:
            loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y_batch, y_pred=y_pred)

    # The metric is not differentiated, so it is computed outside the tape.
    if num_classes == 2:
        acc = tf.keras.metrics.binary_accuracy(y_batch, y_pred)
    else:
        acc = tf.metrics.sparse_categorical_accuracy(y_batch, y_pred)

    # Both halves are updated jointly from the same per-sample loss vector.
    var = f.trainable_variables + g.trainable_variables
    grad = tape.gradient(loss, var)
    optimizer.apply_gradients(zip(grad, var))

    # Batch means in one op instead of a Python-level sum over tensor elements.
    iter_loss = tf.reduce_mean(loss)
    iter_acc = tf.reduce_mean(acc)
    log.append([iter_loss, iter_acc])
    iter_count += 1

    # z is the activation computed before this step's weight update.
    train_ref.append(x_batch)
    z_ref.append(z)

    # Report on iterations 1, 1 + log_frequency, 1 + 2 * log_frequency, ...
    if (iter_count - 1) % log_frequency == 0:
        print("Iteration %04d: Training loss: %0.4f training accuracy: %0.4f" % (iter_count, iter_loss, iter_acc))

Iteration 0001: Training loss: 0.6799 training accuracy: 0.6875


In [6]:
def make_generator(input_shape, output_dim=19):
    """Build a network that maps an intermediate activation back to an input.

    The network is five ReLU dense layers of width 512, 256, 128, 64 and 32,
    followed by a sigmoid dense layer, so every output value lies in (0, 1).

    Args:
        input_shape: Shape of one generator input (without the batch axis).
        output_dim: Width of the sigmoid output layer. The attack cell feeds
            the generator output into a copy of ``f``, so this has to equal
            the input width of ``f``. Defaults to 19.

    Returns:
        A ``tf.keras.Model`` from ``input_shape`` to ``output_dim`` values.
    """
    xin = tf.keras.layers.Input(input_shape)
    x = xin
    for units in (512, 256, 128, 64, 32):
        x = tf.keras.layers.Dense(units, activation="relu")(x)
    x = tf.keras.layers.Dense(output_dim, activation="sigmoid")(x)
    return tf.keras.Model(xin, x)

def make_random_generator(input_shape, output_dim=19):
    """Build a generator with the same layer stack as ``make_generator``.

    The network is five ReLU dense layers of width 512, 256, 128, 64 and 32,
    followed by a sigmoid dense layer, so every output value lies in (0, 1).
    NOTE(review): presumably meant for an input that concatenates the
    intermediate activation with random noise (see the disabled lines in the
    attack cell) -- confirm the intended use.

    Args:
        input_shape: Shape of one generator input (without the batch axis).
        output_dim: Width of the sigmoid output layer. Defaults to 19.

    Returns:
        A ``tf.keras.Model`` from ``input_shape`` to ``output_dim`` values.
    """
    xin = tf.keras.layers.Input(input_shape)
    x = xin
    for units in (512, 256, 128, 64, 32):
        x = tf.keras.layers.Dense(units, activation="relu")(x)
    x = tf.keras.layers.Dense(output_dim, activation="sigmoid")(x)
    return tf.keras.Model(xin, x)

In [7]:
# Working copy of f with the same architecture and the same weights; the
# attack updates this copy and leaves f itself untouched.
f_temp = tf.keras.models.clone_model(f)
f_temp.set_weights(f.get_weights())

generator = make_generator(intermediate_shape)
# generator = make_random_generator((128,))

# Separate optimizers: the generator moves much faster than the model copy.
x_opt = tf.keras.optimizers.Adam(learning_rate=0.001)
f_opt = tf.keras.optimizers.Adam(learning_rate=0.00001)

# One loss object is reused for every step instead of building a new one each
# time it is needed.
mse = tf.keras.losses.MeanSquaredError()

attacked_batches = 10   # number of recorded training batches to attack
rounds_per_batch = 200  # alternations between generator and f_temp updates
generator_steps = 20    # generator updates per round
f_temp_steps = 1        # f_temp updates per round

# inference_batches = attack_ds.batch(batch_size=32, drop_remainder=True).repeat(-1).take(attack_iterations)

attack_iter_count = 0

# for (x_batch, y_batch) in inference_batches:
for i in range(attacked_batches):

    # Walk backwards from the last recorded training iteration: z is the
    # observed activation, x the true input kept only for evaluation.
    z = z_ref[iterations - i - 1]
    x = train_ref[iterations - i - 1]

    # x_temp = tf.Variable(2 * np.random.rand(*(x.numpy().shape)) - 1)

    # x_temp = np.zeros_like(x.numpy())
    # x_temp.fill(0.5)
    # x_temp = tf.Variable(x_temp)

    for _ in range(rounds_per_batch):

        # Phase 1: train the generator so that f_temp(generator(z)) matches z.
        for _ in range(generator_steps):
            with tf.GradientTape() as tape:
                x_temp = generator(z, training=True)
                # x_temp = generator(tf.concat([z, tf.constant(np.random.rand(*(z.numpy().shape)).astype("float32"))],1), training=True)
                loss_x = mse(f_temp(x_temp, training=False), z)
            generator_vars = generator.trainable_variables
            grad = tape.gradient(loss_x, generator_vars)
            x_opt.apply_gradients(zip(grad, generator_vars))
            # loss = lambda: tf.keras.losses.MeanSquaredError()(f_temp(x_temp, training=False), z)
            # x_opt.minimize(loss, var_list=[x_temp])

        # Phase 2: nudge f_temp on the same objective. x_temp was produced
        # outside this tape, so it is a constant here and only f_temp's
        # variables receive gradients.
        for _ in range(f_temp_steps):
            with tf.GradientTape() as tape:
                loss_f = mse(f_temp(x_temp, training=True), z)
            f_temp_vars = f_temp.trainable_variables
            grad = tape.gradient(loss_f, f_temp_vars)
            f_opt.apply_gradients(zip(grad, f_temp_vars))

    # Reconstruction error against the true inputs, next to the error of a
    # uniform-random guess in [0, 1) as a baseline.
    attack_mse = mse(x_temp, x)
    rg_uniform = mse(x, np.random.rand(*(x.numpy().shape)))
    attack_iter_count += 1
    print("Iteration %04d: RG: %0.4f reconstruction validation: %0.4f" % (attack_iter_count, rg_uniform, attack_mse))

Iteration 0001: RG: 0.2381 reconstruction validation: 0.1719
Iteration 0002: RG: 0.2353 reconstruction validation: 0.1035
Iteration 0003: RG: 0.2425 reconstruction validation: 0.0881
Iteration 0004: RG: 0.2391 reconstruction validation: 0.0785
Iteration 0005: RG: 0.2379 reconstruction validation: 0.0706
Iteration 0006: RG: 0.2365 reconstruction validation: 0.0681
Iteration 0007: RG: 0.2338 reconstruction validation: 0.0627
Iteration 0008: RG: 0.2422 reconstruction validation: 0.0581
Iteration 0009: RG: 0.2337 reconstruction validation: 0.0542
Iteration 0010: RG: 0.2426 reconstruction validation: 0.0611
