In [None]:
import os
# Number GPUs by PCI bus order and expose only physical GPUs 1 and 2. CUDA
# renumbers the visible devices, so the model below addresses them as
# /GPU:0 and /GPU:1. This cell runs before `import tensorflow`, presumably
# because CUDA reads these variables when it initializes.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1,2",
})

In [None]:
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Runtime / graph-optimizer settings for the benchmark.
# NOTE(review): TF documents that thread-pool settings only take effect if they
# are set before the runtime initializes, so keep these calls ahead of any TF op.
tf.config.threading.set_inter_op_parallelism_threads(16)
# Asynchronous eager execution; errors may be reported later than the op that caused them.
tf.config.experimental.set_synchronous_execution(False)
# Request XLA JIT compilation.
tf.config.optimizer.set_jit(True)
# Grappler automatic mixed-precision graph rewrite; the training cell wraps the
# optimizer in a dynamic LossScaleOptimizer alongside this.
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
# Global batch size used by the input pipeline and the FPS calculation. It must
# also be divisible by PipelineCNN's `splits`, because call() uses tf.split.
BATCH_SIZE = 320

In [None]:
class TimeHistory(tf.keras.callbacks.Callback):
    """Keras callback that records the elapsed time of every training epoch.

    After ``fit`` returns, ``times`` holds one duration in seconds per epoch,
    in epoch order.
    """

    def on_train_begin(self, logs=None):
        # Reset on every fit() so durations from an earlier run don't leak in.
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        # perf_counter is monotonic, so the measured interval is unaffected by
        # wall-clock adjustments (unlike time.time()).
        self.epoch_time_start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.perf_counter() - self.epoch_time_start)
        
def normalize(ndarray):
    """Return a float32 copy of `ndarray` mapped from [0, 255] onto [-1, 1].

    The input array is not modified: ``astype`` always produces a new array.
    """
    as_float = ndarray.astype("float32")
    # x / 127.5 spans [0, 2]; shifting by one centres it on zero.
    return as_float / 127.5 - 1

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Assumes labels are contiguous integer ids 0..K-1, so K = max + 1. int() keeps a
# plain Python int (rather than a NumPy scalar) for to_categorical / Dense.
num_classes = int(np.max(y_train)) + 1
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
x_train = normalize(x_train)
x_test = normalize(x_test)
num_train = x_train.shape[0]

# Infinite stream of fixed-size batches. repeat() lets fit() run any number of
# steps. drop_remainder keeps every batch exactly BATCH_SIZE, which PipelineCNN's
# tf.split needs in order to divide the batch evenly.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(128)
)

# NOTE(review): one "epoch" here covers roughly half of the training set,
# presumably to shorten the benchmark. Peak FPS below is computed from these
# same steps, so it stays consistent. Confirm this is intended.
steps_per_epoch = int(num_train / BATCH_SIZE / 2) + 1

In [None]:
def cnn_block():
    """Build a fresh conv stage: four same-padded 3x3 ReLU convs, then batch norm.

    Padding is "same" and there is no pooling, so spatial size is unchanged.
    Each call returns a new ``Sequential`` with its own (unbuilt) weights.
    """
    conv_widths = (128, 512, 512, 128)
    stage = [
        layers.Conv2D(width, (3, 3), padding="same", activation="relu", kernel_initializer="he_uniform")
        for width in conv_widths
    ]
    stage.append(layers.BatchNormalization(fused=True))
    return tf.keras.models.Sequential(stage)

class PipelineCNN(tf.keras.Model):
    """CNN split across two GPUs, run on micro-batches of each input batch.

    Stage 1 (``forward_1``) runs under ``/GPU:0`` and stage 2 (``forward_2``)
    under ``/GPU:1``. ``call`` splits the input batch into ``splits``
    micro-batches along axis 0. It sends every micro-batch through stage 1,
    then through stage 2, and concatenates the predictions back into one batch.

    Args:
        splits: Number of micro-batches per input batch (>= 1). The batch
            dimension must be divisible by it, because ``tf.split`` with an
            integer count requires an even split.

    Raises:
        ValueError: If ``splits`` is less than 1.
    """

    def __init__(self, splits=1):
        super().__init__()
        if splits < 1:
            raise ValueError(f"splits must be >= 1, got {splits}")
        self.splits = splits
        # NOTE(review): Keras layers create their weights when first called
        # (build), so weight placement is presumably decided by the device
        # scopes in forward_1 / forward_2 rather than these. Verify with
        # tf.debugging.set_log_device_placement.
        with tf.device('/GPU:0'):
            self.conv_0 = layers.Conv2D(32, (3,3), padding="same", activation="relu", kernel_initializer="he_uniform")
            self.maxpool_1 = layers.MaxPooling2D((2,2))
            # The block factory is `cnn_block`; the previous `ret_cnn_block`
            # is not defined in this notebook and raised NameError on a fresh kernel.
            self.block_1 = cnn_block()
        with tf.device('/GPU:1'):
            self.block_2 = cnn_block()
            self.block_3 = cnn_block()
            self.block_4 = cnn_block()
            self.maxpool_2 = layers.MaxPooling2D((2,2))
            self.flat = layers.Flatten()
            self.classifier = layers.Dense(num_classes, activation="softmax")

    def forward_1(self, split_batch):
        """Stage 1 on /GPU:0: stem conv, first conv block, 2x2 max-pool."""
        with tf.device('/GPU:0'):
            x = self.conv_0(split_batch)
            x = self.block_1(x)
            return self.maxpool_1(x)

    def forward_2(self, split_batch):
        """Stage 2 on /GPU:1: three conv blocks, pool, flatten, softmax head."""
        with tf.device('/GPU:1'):
            x = self.block_2(split_batch)
            x = self.block_3(x)
            x = self.block_4(x)
            x = self.maxpool_2(x)
            x = self.flat(x)
            return self.classifier(x)

    def call(self, inputs):
        """Run every micro-batch through both stages and re-join the outputs.

        Handles any ``self.splits``. The old version hardcoded two
        micro-batches, which raised IndexError for ``splits=1`` and silently
        dropped data for ``splits > 2``.
        """
        with tf.device('/GPU:0'):
            micro_batches = tf.split(inputs, self.splits, axis=0, num=self.splits, name="split_batch")
        # All stage-1 work is issued before any stage-2 work. This is presumably
        # meant to let GPU:0 start on micro-batch i+1 while GPU:1 handles
        # micro-batch i.
        stage_1_outputs = [self.forward_1(mb) for mb in micro_batches]
        predictions = [self.forward_2(h) for h in stage_1_outputs]
        with tf.device('/GPU:1'):
            return tf.concat(predictions, 0, name="concat_batch")

In [None]:
# Two micro-batches per step, pipelined across the two visible GPUs.
model = PipelineCNN(splits=2)

# Adam wrapped with dynamic loss scaling, used alongside the
# auto_mixed_precision graph rewrite enabled in the config cell.
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.Adam(), "dynamic"
)
model.compile(optimizer=opt, loss="categorical_crossentropy")

time_history = TimeHistory()
train_log = model.fit(
    train_dataset,
    steps_per_epoch=steps_per_epoch,
    epochs=2,
    verbose=1,
    callbacks=[time_history],
)

# Samples/sec of the fastest epoch. Taking the min is presumably meant to skip
# one-off costs (e.g. compilation) that land in the first epoch.
peak_fps = int(steps_per_epoch*BATCH_SIZE/min(time_history.times))

print("* Params:", model.count_params())
print("* Peak FPS:", peak_fps)

Peak FPS: 1975