In [1]:
# Preload the CUDA / cuDNN DLLs by absolute path, because on this Windows 10 setup
# TensorFlow was not finding them through PATH.
#    - assumes CUDA Toolkit v10.0.130 is installed under CUDA_DIR
#    - assumes cuDNN 7 was unpacked under CUDNN_DIR
#    - assumes nvcuda.dll is in %SystemRoot%\System32 (normally C:\Windows\System32)
# ctypes.WinDLL only exists on Windows, so on other platforms the preload is skipped
# and the handle names are bound to None.
import ctypes
import os

CUDA_DIR = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0"
CUDNN_DIR = "C:\\tools\\cuda"

if os.name == "nt":
    # module-level names keep the loaded libraries referenced for the kernel's lifetime
    hllDll = ctypes.WinDLL(os.path.join(CUDA_DIR, "bin", "cudart64_100.dll"))
    tllDll = ctypes.WinDLL(os.path.join(CUDNN_DIR, "bin", "cudnn64_7.dll"))
    cllDll = ctypes.WinDLL(os.path.join(CUDA_DIR, "bin", "cublas64_100.dll"))
    ullDll = ctypes.WinDLL(os.path.join(CUDA_DIR, "extras", "CUPTI", "libx64", "cupti64_100.dll"))
    nllDll = ctypes.WinDLL(os.path.join(os.environ.get("SystemRoot", "C:\\Windows"), "System32", "nvcuda.dll"))
else:
    hllDll = tllDll = cllDll = ullDll = nllDll = None

#all necessary imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model,Input
from tensorflow.keras.layers import Conv2D,MaxPool2D,Flatten,Dense,Layer,InputLayer
from tensorflow.keras.activations import softmax
import tensorflow.keras.backend as K
import os
import numpy as np


# TensorBoard inside the Jupyter Notebook (uncomment the magic below to enable it)
#%load_ext tensorboard

#set image formatting for TF
K.set_image_data_format('channels_last')

#to prevent memory allocation fails on GPU
#    -assumes one GPU in system being used for training
tf.config.experimental.set_memory_growth(tf.config.experimental.list_physical_devices('GPU')[0],True)

K.clear_session()

In [2]:
class Expert(Layer):
    """Multiple Conv2D layers wrapped into a single "expert" layer.

    Receives the full GateLayer output, selects this expert's routed input
    (``gate_output[0][num]``) and applies three SRCNN-style convolutions
    (https://arxiv.org/abs/1501.00092). The output has 3 channels; with 'valid'
    padding the spatial size shrinks by (9-1) + (5-1) + (5-1) = 16 pixels.
    """

    def __init__(self, num, **kwargs):
        """
        Initialize an Expert layer

        num  -- the assigned "number" for the expert, can be from 0 to num_experts - 1
        """
        super(Expert, self).__init__(**kwargs)

        # GateLayer returns [expert_inputs, mask, weights]: take expert_inputs,
        # then this expert's entry in that list
        self.split_1 = Splitter(0)
        self.split_2 = Splitter(num)

        # convolutions according to https://arxiv.org/abs/1501.00092;
        # 'valid' padding trims (kernel_size - 1) pixels from H and W at each step
        self.conv_3 = optConv2D(64, 9, strides=(1, 1), padding='valid', activation='relu', dynamic=True)
        self.conv_4 = optConv2D(32, 5, strides=(1, 1), padding='valid', activation='relu', dynamic=True)
        # NOTE(review): SRCNN's reconstruction layer is linear; confirm relu is intended here
        self.conv_5 = optConv2D(3, 5, strides=(1, 1), padding='valid', activation='relu', dynamic=True)

    def _convs(self):
        # the convolution stack, in application order
        return (self.conv_3, self.conv_4, self.conv_5)

    def call(self, inputs):
        out = self.split_2(self.split_1(inputs))
        for conv in self._convs():
            out = conv(out)
        return out

    def compute_output_shape(self, input_shape):
        """Shape after selecting this expert's input and applying the three convs.

        input_shape -- the (nested) GateLayer output shape, as received by call()
        """
        shape = self.split_2.compute_output_shape(self.split_1.compute_output_shape(input_shape))
        for conv in self._convs():
            shape = conv.compute_output_shape(shape)
        return shape

In [9]:
class GateLayer(Layer):
    """Noisy top-k gate (see https://arxiv.org/pdf/1701.06538.pdf).

    Inputs: ``[images, gater_weights]`` where ``images`` is the original image batch
    and ``gater_weights`` is the Gater output of shape ``(batch, num_experts)``.

    Outputs ``[expert_inputs, mask, output_weights]``:
      * ``expert_inputs`` -- list of ``num_experts`` tensors shaped like ``images``.
        Sample ``i`` of entry ``j`` is the original image if expert ``j`` is in that
        sample's top-k, otherwise a "null" image filled with ``NULL_VALUE``.
      * ``mask`` -- float32 ``(batch, num_experts)``: 1.0 for each sample's top-k
        experts (exactly k per row), 0.0 elsewhere.
      * ``output_weights`` -- ``(batch, num_experts)`` softmax over the top-k noisy
        logits; experts outside the top-k get weight 0.
    """

    # fill value of the placeholder image routed to experts that were not selected
    NULL_VALUE = -1.0

    def __init__(self, num_experts, k, batch_size, **kwargs):
        """
        num_experts -- number of experts being gated
        k           -- number of experts each sample is routed to
        batch_size  -- kept for backward compatibility; routing works for any batch size
        """
        super(GateLayer, self).__init__(**kwargs)
        self.num_experts = num_experts
        self.k = k
        self.batch_size = batch_size

    def build(self, input_shape):
        super(GateLayer, self).build(input_shape)

        # One trainable scale and one trainable noise weight per expert, shared by all
        # samples (shape (num_experts,) broadcasts over the batch dimension).
        self.ws = self.add_weight(name='ws',
                                  dtype=tf.float32,
                                  shape=(self.num_experts,),
                                  initializer=tf.initializers.zeros(),
                                  trainable=True)

        self.noise_weight = self.add_weight(name='noise_weight',
                                            dtype=tf.float32,
                                            shape=(self.num_experts,),
                                            initializer=tf.initializers.zeros(),
                                            trainable=True)

    def call(self, inputs):
        """Gate the Gater weights with noisy top-k and route the images to the experts."""
        images, weights = inputs[0], inputs[1]

        ### noisy top-k gating of the Gater weights ###
        noise = K.random_normal(tf.shape(weights))
        noisy_logits = (weights * self.ws) + (noise * K.softplus(weights * self.noise_weight))

        # binary mask with exactly k ones per row: one-hot the top-k indices and merge them
        top_k_indices = tf.math.top_k(noisy_logits, k=self.k, sorted=False).indices
        mask = tf.reduce_max(tf.one_hot(top_k_indices, depth=self.num_experts, dtype=tf.float32), axis=-2)
        keep = mask > 0

        # non top-k logits become -inf so the softmax assigns them exactly 0
        masked_logits = tf.where(keep, noisy_logits, tf.fill(tf.shape(noisy_logits), -np.inf))
        output_weights = K.softmax(masked_logits)

        ### route the images to the experts ###
        null_images = tf.fill(tf.shape(images), tf.cast(self.NULL_VALUE, images.dtype))
        # per-sample selection flag reshaped to (batch, 1, ..., 1) to broadcast over the image dims
        flag_shape = [-1] + [1] * (len(images.shape) - 1)
        expert_inputs = [tf.where(tf.reshape(keep[:, j], flag_shape), images, null_images)
                         for j in range(self.num_experts)]

        return [expert_inputs, mask, output_weights]

    def compute_output_shape(self, input_shape):
        """[[image_shape] * num_experts, weight_shape, weight_shape] -- see call()."""
        image_shape, weight_shape = input_shape[0], input_shape[1]
        return [[image_shape] * self.num_experts, weight_shape, weight_shape]

In [4]:
class Gater(Layer):
    """Takes in the original image input and generates weights for each expert.

    Output: ``(batch, num_experts)`` softmax probabilities, with an L2 activity
    regularizer (factor 0.001) applied to them.
    """

    def __init__(self, num_experts, **kwargs):
        super(Gater, self).__init__(**kwargs)

        self.num_experts = num_experts
        # two 4x4 convolutions, 4 filters each, stride 2
        self.conv_1 = Conv2D(4, 4, 2, activation='relu')
        self.conv_2 = Conv2D(4, 4, 2, activation='relu')
        self.flatten_3 = Flatten()
        # unique names so the two Dense layers' weight variables do not share a name
        self.dense_4 = Dense(num_experts, activation='relu', name='gater_dense_4')
        self.dense_5 = Dense(num_experts, activation='softmax', name='gater_dense_5',
                             activity_regularizer=tf.keras.regularizers.l2(0.001))

    def call(self, inputs):
        out = inputs
        for layer in (self.conv_1, self.conv_2, self.flatten_3, self.dense_4, self.dense_5):
            out = layer(out)
        return out

    def compute_output_shape(self, input_shape):
        # essentially (batch_size, num_experts)
        return (input_shape[0], self.num_experts)

In [5]:
class Splitter(Layer):
    """Picks a single element out of a list-valued input by position.

    ``num`` follows Python indexing rules, so a negative value counts from the end
    of the list (e.g. -1 selects the last element).
    """

    def __init__(self, num, **kwargs):
        super().__init__(**kwargs)
        self.num = num  # position of the element to hand on

    def _pick(self, sequence):
        # shared by call() and compute_output_shape() so both always select the same slot
        return sequence[self.num]

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        return self._pick(inputs)

    def compute_output_shape(self, input_shape):
        return self._pick(input_shape)

In [6]:
class Combine(Layer):
    """Combines the expert outputs based on the processed weights from GateLayer.

    Inputs: ``[experts, weights]`` -- ``experts`` is a list of ``num_experts`` tensors
    with identical shape ``(batch, ...)``; ``weights`` (the last input) is
    ``(batch, num_experts)``. Output: ``sum_j weights[:, j] * experts[j]``, shaped like
    a single expert output.
    """

    def __init__(self, num_experts, top_k, batch_size, **kwargs):
        super(Combine, self).__init__(**kwargs)
        self.num_experts = num_experts
        self.top_k = top_k
        # kept for backward compatibility; call() takes the batch size from the tensors
        self.batch_size = batch_size

    def build(self, input_shape):
        super(Combine, self).build(input_shape)

    def call(self, inputs):
        experts = inputs[0]
        weights = inputs[-1]
        weighted_experts = []
        for idx, expert in enumerate(experts):
            # column idx of the weights as (batch, 1, ..., 1) so it broadcasts over the
            # expert output; works for any batch size, including a short final batch
            weight_col = tf.reshape(weights[:, idx], [-1] + [1] * (len(expert.shape) - 1))
            weighted_experts.append(expert * weight_col)
        return tf.reduce_sum(weighted_experts, axis=0)

    def compute_output_shape(self, input_shape):
        # same shape as the first expert output
        return input_shape[0][0]

In [7]:
class optConv2D(Conv2D):
    """Conv2D that skips the convolution when the whole input batch is "null".

    GateLayer routes an image filled with -1 to experts that were not selected. If an
    entire batch consists of that sentinel, the convolution is skipped and an output
    filled with the same sentinel is returned. The next optConv2D in the stack then
    skips as well.
    """

    # sentinel marking a "null" tensor; must match GateLayer's fill value (-1)
    NULL_VALUE = -1.0

    @tf.function
    def call(self, inputs):
        # Check the sentinel rather than zeros: an all-zero batch (e.g. a dead-relu
        # activation) is real data and must still be convolved.
        if tf.reduce_all(tf.equal(inputs, self.NULL_VALUE)):
            return tf.zeros(self.compute_output_shape(inputs.shape)) + self.NULL_VALUE
        else:
            return super(optConv2D, self).call(inputs)

In [8]:
class MOE(Model):
    """Mixture-of-experts model: Gater -> GateLayer -> Experts -> Combine.

    ``call`` returns ``(output, mask)``. ``output`` is the gate-weighted sum of the
    expert outputs. ``mask`` is the ``(batch, num_experts)`` top-k selection mask
    from the GateLayer.
    """

    def __init__(self, num_experts, top_k, batch_size, train_data_dir, image_shape=(32, 32, 3), **kwargs):
        """
        num_experts    -- number of Expert sub-networks
        top_k          -- number of experts each sample is routed to
        batch_size     -- forwarded to GateLayer and Combine
        train_data_dir -- stored on the model; not used in this class
        image_shape    -- (height, width, channels) of one input image
        """
        super(MOE, self).__init__(**kwargs)
        self.train_data_dir = train_data_dir

        self.num_experts = num_experts
        self.top_k = top_k
        self.batch_size = batch_size

        self.inputlayer = InputLayer(input_shape=tuple(image_shape))

        # gater
        self.gater = Gater(self.num_experts, input_shape=tuple(image_shape), name='gater')

        self.gate = GateLayer(num_experts=self.num_experts, k=self.top_k, batch_size=self.batch_size, dynamic=True)

        self.experts = [Expert(i) for i in range(self.num_experts)]

        # Not named `get_weights`: that attribute would shadow keras.Model.get_weights().
        self.get_gate_weights = Splitter(-1)
        self.get_mask = Splitter(-2)

        self.combine = Combine(num_experts=self.num_experts, top_k=self.top_k, batch_size=self.batch_size, dynamic=True)

    def call(self, inputs):
        images = self.inputlayer(inputs)
        gater_weights = self.gater(images)
        gate_outputs = self.gate([images, gater_weights])

        expert_outputs = [expert(gate_outputs) for expert in self.experts]
        gate_weights = self.get_gate_weights(gate_outputs)
        mask = self.get_mask(gate_outputs)

        output = self.combine([expert_outputs, gate_weights])
        return output, mask

    def compute_output_shape(self, input_shape):
        """Shapes of ``(output, mask)`` for an image batch of shape ``input_shape``."""
        input_shape = tf.TensorShape(input_shape)
        # every Expert is built with the same 'valid'-padded conv stack, so the first
        # one determines the (smaller) output shape
        first = self.experts[0]
        output_shape = input_shape
        for conv in (first.conv_3, first.conv_4, first.conv_5):
            output_shape = conv.compute_output_shape(output_shape)
        return output_shape, tf.TensorShape([input_shape[0], self.num_experts])