In [2]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten 
from tensorflow.keras import Model
import matplotlib.pyplot as plt

# Fetch the MNIST arrays; load_data() returns a (train, test) pair of
# (images, labels) tuples.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Training pipeline: cast the images to float32 and divide by 255, pair them
# with their labels, shuffle with a 1024-element buffer, then group into
# batches of 32.
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train.astype('float32') / 255, y_train))
train_ds = train_ds.shuffle(1024)
train_ds = train_ds.batch(32)

# Evaluation pipeline: same scaling and batch size, but no shuffling.
test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test.astype('float32') / 255, y_test))
test_ds = test_ds.batch(32)

In [3]:
class VFLPassiveModel(Model):
    """Passive-party model: flatten the input and apply one Dense layer.

    The Dense layer has 10 units and is created without an activation
    argument, so the model returns the layer's raw outputs.
    """

    def __init__(self):
        """Create the Flatten layer and the 10-unit Dense layer ("dense1")."""
        super(VFLPassiveModel, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(10, name="dense1")

    def call(self, x):
        """Return ``d1(flatten(x))`` for the input batch ``x``."""
        return self.d1(self.flatten(x))

In [26]:
class VFLActiveModel(Model):
    """Active-party model: sum the incoming tensors and apply a softmax.

    The model has no trainable weights of its own; it only combines the
    outputs it receives and normalises the sum.
    """

    def __init__(self):
        """Create the Add and Softmax layers once, at construction time."""
        super(VFLActiveModel, self).__init__()
        self.added = tf.keras.layers.Add()
        # Built here rather than inside call() so that a single Softmax layer
        # is tracked by the model and reused, instead of allocating a fresh,
        # untracked layer object on every forward pass.
        self.softmax = tf.keras.layers.Softmax()

    def call(self, x):
        """Return the softmax of the element-wise sum of the tensors in ``x``.

        Args:
            x: A list of tensors accepted by ``tf.keras.layers.Add``.

        Returns:
            The output of ``Softmax`` (default axis) applied to the sum.
        """
        x = self.added(x)
        return self.softmax(x)

In [32]:
passive_model = VFLPassiveModel()
active_model = VFLActiveModel()

# active_model ends in a softmax, so the loss is left at its default of
# consuming probabilities rather than logits.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Running per-epoch aggregates; they are reset at the end of every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

EPOCHS = 5

for epoch in range(EPOCHS):
    # ---- Training: one optimizer step per batch of images and labels ----
    for images, labels in train_ds:
        # Passive side: forward pass recorded on its own tape.
        # passive_model sends passive_output to active_model.
        with tf.GradientTape() as passive_tape:
            passive_output = passive_model(images)

        # Active side: a separate tape that sees passive_output only as an
        # input tensor, so it must be watched explicitly. Keeping this tape
        # outside passive_tape stops the passive tape from recording the
        # active side's backward ops.
        with tf.GradientTape() as active_tape:
            active_tape.watch(passive_output)
            # The same passive output is fed twice and summed by active_model.
            active_output = active_model([passive_output, passive_output])
            loss = loss_object(labels, active_output)

        # active_model sends passive_output_gradients back to passive_model
        passive_output_gradients = active_tape.gradient(loss, passive_output)

        # Finish the chain rule on the passive side. Seeding the backward pass
        # with the received gradients via output_gradients is equivalent to
        # differentiating sum(passive_output * constant_gradients), without a
        # .numpy() round trip to detach the gradients.
        passive_weight_gradients = passive_tape.gradient(
            passive_output,
            passive_model.trainable_variables,
            output_gradients=passive_output_gradients)
        optimizer.apply_gradients(zip(passive_weight_gradients, passive_model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, active_output)

    # ---- Evaluation: forward passes only, no weight updates ----
    for test_images, test_labels in test_ds:
        passive_output = passive_model(test_images)
        active_output = active_model([passive_output, passive_output])
        t_loss = loss_object(test_labels, active_output)

        test_loss(t_loss)
        test_accuracy(test_labels, active_output)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1,
                        train_loss.result(),
                        train_accuracy.result()*100,
                        test_loss.result(),
                        test_accuracy.result()*100))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

Epoch 1, Loss: 0.396317720413208, Accuracy: 88.98332977294922, Test Loss: 0.2956085503101349, Test Accuracy: 91.50999450683594
Epoch 2, Loss: 0.289332777261734, Accuracy: 91.95999908447266, Test Loss: 0.28157228231430054, Test Accuracy: 91.93999481201172
Epoch 3, Loss: 0.27582597732543945, Accuracy: 92.39500427246094, Test Loss: 0.28288891911506653, Test Accuracy: 92.04000091552734
Epoch 4, Loss: 0.2672642767429352, Accuracy: 92.70500183105469, Test Loss: 0.2774609327316284, Test Accuracy: 92.06999969482422
Epoch 5, Loss: 0.26180967688560486, Accuracy: 92.7933349609375, Test Loss: 0.2739064395427704, Test Accuracy: 92.29000091552734
