In [1]:
import torch
import torch.nn as nn

class SandwichBatchNorm2d(nn.Module):
    """Class-conditional "sandwich" batch norm for NCHW feature maps.

    A standard ``BatchNorm2d`` (with its own learnable affine transform) is
    followed by a second, per-class affine transform looked up from an
    embedding table::

        out = gamma[y] * BN(x) + beta[y]

    Args:
        num_features: number of channels ``C`` of the input.
        num_classes: number of class labels; ``y`` must index into
            ``[0, num_classes)``.

    Forward shapes:
        x: ``(N, C, H, W)``
        y: integer class indices, shape ``(N,)`` or ``(N, 1)``
        returns: ``(N, C, H, W)``
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=True)
        # One row per class: first half of the row is gamma, second half is beta.
        self.embed = nn.Embedding(num_classes, num_features * 2)
        # Start every class close to the identity transform: gamma ~ N(1, 0.02), beta = 0.
        # Edit the parameter in place under no_grad instead of going through `.data`.
        with torch.no_grad():
            self.embed.weight[:, :num_features].normal_(1, 0.02)
            self.embed.weight[:, num_features:].zero_()

    def forward(self, x, y):
        out = self.bn(x)
        # Split on the last axis so that labels of shape (N,) and (N, 1) both work;
        # splitting on axis 1 fails for (N, 1) labels (that axis has size 1).
        gamma, beta = self.embed(y).chunk(2, dim=-1)
        gamma = gamma.reshape(-1, self.num_features, 1, 1)
        beta = beta.reshape(-1, self.num_features, 1, 1)
        return gamma * out + beta

In [12]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class SandwichBatchNorm(layers.Layer):
    """Class-conditional "sandwich" batch norm for channels-last (NHWC) feature maps.

    Keras port of the PyTorch ``SandwichBatchNorm2d``: a ``BatchNormalization``
    layer (with its own learnable scale/offset) followed by a per-class affine
    transform looked up from an embedding table::

        out = gamma[y] * BN(x) + beta[y]

    Args:
        num_features: number of channels ``C`` (last axis) of the input.
        num_classes: number of class labels; ``y`` must index into
            ``[0, num_classes)``.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).

    Call args:
        x: ``(N, H, W, C)`` feature map.
        y: integer class indices, shape ``(N,)`` or ``(N, 1)``.
    """

    def __init__(self, num_features, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        # NOTE(review): Keras BatchNormalization defaults (momentum=0.99, epsilon=1e-3)
        # differ from PyTorch BatchNorm2d (momentum=0.1, i.e. 0.9 in Keras convention,
        # eps=1e-5); pass matching values if numerical parity with the PyTorch version matters.
        self.bn = layers.BatchNormalization()
        # One row per class: first half of the row is gamma, second half is beta.
        self.embed = layers.Embedding(
            num_classes,
            num_features * 2,
            embeddings_initializer=keras.initializers.random_normal(1, 0.02),
        )
        # Build eagerly so the table exists, then zero the beta half so every
        # class starts close to the identity transform (gamma ~ N(1, 0.02), beta = 0).
        self.embed.build(None)
        weights = self.embed.get_weights()
        weights[0][:, num_features:] = 0
        self.embed.set_weights(weights)

    def call(self, x, y):
        out = self.bn(x)
        # Split on the last axis so that labels of shape (N,) and (N, 1) both work;
        # splitting on axis 1 fails for (N, 1) labels (that axis has size 1).
        gamma, beta = tf.split(self.embed(y), 2, axis=-1)
        gamma = tf.reshape(gamma, (-1, 1, 1, self.num_features))
        beta = tf.reshape(beta, (-1, 1, 1, self.num_features))
        return gamma * out + beta

In [13]:
class ConvBlock(keras.Model):
    """Conv2D (no bias) -> class-conditional SandwichBatchNorm -> ReLU.

    Args:
        c: number of output channels.
        nc: number of classes, forwarded to the conditional batch norm.
        k: convolution kernel size.
        s: convolution stride ('same' padding).
    """

    def __init__(self, c, nc, k=3, s=1):
        super().__init__()
        # No bias: the batch norm that follows supplies its own offset.
        self.conv1 = layers.Conv2D(c, k, s, padding='same', use_bias=False)
        self.bn1 = SandwichBatchNorm(c, nc)

    def call(self, x, y):
        normed = self.bn1(self.conv1(x), y)
        return tf.nn.relu(normed)

class BottleNeck(keras.Model):
    """Bottleneck residual unit built from class-conditional ConvBlocks.

    Main path: 1x1 ConvBlock (dim1) -> 3x3 ConvBlock (dim1, ``strides``) ->
    1x1 ConvBlock (dim2). The shortcut is projected through a 3x3 ConvBlock
    (dim2, ``strides``) whenever the identity could not be added to the main
    path's output: the channel count differs or the main path is strided.

    Args:
        dim1: channels of the inner (bottleneck) convolutions.
        dim2: output channels.
        nc: number of classes, forwarded to the conditional batch norms.
        strides: stride of the 3x3 convolution and of the shortcut projection
            (int or pair).
    """

    def __init__(self, dim1, dim2, nc, strides=1):
        super().__init__()
        self.conv1 = ConvBlock(dim1, nc, 1)
        self.conv2 = ConvBlock(dim1, nc, 3, strides)
        self.conv3 = ConvBlock(dim2, nc, 1)
        self.downsample_conv = ConvBlock(dim2, nc, 3, strides)
        # Whether the main path changes spatial size; accept both int and (h, w) strides.
        stride_pair = tuple(strides) if isinstance(strides, (tuple, list)) else (strides, strides)
        self.spatial_downsample = stride_pair != (1, 1)

    def call(self, x, y):
        out = self.conv1(x, y)
        out = self.conv2(out, y)
        out = self.conv3(out, y)
        # Project the shortcut when a plain identity cannot be added: channels differ,
        # or the strided 3x3 conv shrank the spatial dims.
        if self.spatial_downsample or x.shape[-1] != out.shape[-1]:
            x = self.downsample_conv(x, y)
        # NOTE(review): ConvBlock always ends in ReLU, so both the main path and the
        # projected shortcut are rectified before this addition, and the projection is
        # 3x3 rather than the 1x1 used in standard ResNet bottlenecks — confirm intended.
        return tf.nn.relu(out + x)
    
class ResidualBlock(keras.Model):
    """One ResNet stage: a chain of BottleNeck units sharing the same widths.

    Only the first unit receives ``strides`` (so only it can downsample); the
    remaining ``num_bottlenecks - 1`` units use stride 1. At least one unit is
    always created.
    """

    def __init__(self, num_bottlenecks, dim1, dim2, nc, strides=1):
        super().__init__()
        head = BottleNeck(dim1, dim2, nc, strides=strides)
        tail = [BottleNeck(dim1, dim2, nc) for _ in range(num_bottlenecks - 1)]
        self.bottlenecks = [head] + tail

    def call(self, x, y):
        h = x
        for unit in self.bottlenecks:
            h = unit(h, y)
        return h

def make_ResNetwithSBN(num_classes):
    """Build a functional ResNet-style classifier with class-conditional batch norms.

    Args:
        num_classes: number of classes; sizes both the per-class SBN embedding
            tables and the final Dense layer.

    Returns:
        A ``keras.Model`` called as ``model([images, labels])``:
        ``images`` has shape ``(N, H, W, 3)`` (any H, W), ``labels`` holds one
        integer class index per example, shape ``(N,)``, with values in
        ``[0, num_classes)``; the output is linear logits of shape
        ``(N, num_classes)``.
    """
    x_input = layers.Input((None, None, 3))
    # Labels are integer class indices, one scalar per example, because every SBN
    # layer looks them up in an Embedding table. Note `(num_classes)` would be a
    # plain int, declaring a float vector of length num_classes instead.
    y = layers.Input(shape=(), dtype='int32')

    x = ConvBlock(64, num_classes, 7, 2)(x_input, y)
    x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    # (num_bottlenecks, dim1, dim2, strides) for each residual stage.
    stages = (
        (3, 64, 128, 1),
        (4, 128, 256, 2),
        (6, 256, 512, 2),
        (3, 512, 1024, 2),
    )
    for num_bottlenecks, dim1, dim2, strides in stages:
        x = ResidualBlock(
            num_bottlenecks=num_bottlenecks,
            dim1=dim1,
            dim2=dim2,
            nc=num_classes,
            strides=strides,
        )(x, y)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes)(x)
    return keras.Model([x_input, y], x)

In [29]:
# Number of label classes: sizes the classifier head and the per-class SBN embeddings.
num_classes = 10
model = make_ResNetwithSBN(num_classes=num_classes)

In [30]:
# Smoke test: 10 all-zero 224x224 RGB images with random integer class labels.
# The all-zero logits are expected for a zero input with fresh weights: the convs
# have no bias, the SBN beta half is zeroed, and (with Keras defaults) the BN offset
# and Dense bias start at zero. Use a non-zero image for a more informative check.
image = tf.zeros((10, 224, 224, 3))
batch_size = image.shape[0]
labels = tf.experimental.numpy.random.randint(0, num_classes, batch_size)
model([image, labels])

<tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>