## Compare the current (PyTorch) implementation with the TF implementation

### Minimal model implementation in TF
https://github.com/google-research/fixmatch/blob/08d9b83d7cc87e853e6afc5a86b12aacff56cdea/libml/models.py#L62

In [1]:
# !pip install tensorflow==1.14.0

In [2]:
import tensorflow as tf
tf.__version__

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


'1.14.0'

In [4]:
import functools


class ResNet:
    """Stripped-down TF1 (`tf.layers`) ResNet classifier used as a numerical reference.

    The instance carries only what `classifier` reads: `dataset.mean`,
    `dataset.std` (input normalisation constants) and `nclass` (number of
    output logits).
    """
    
    # Bare attribute container; `__init__` attaches `mean` and `std` to an instance of it.
    class Dataset: pass
    
    def __init__(self):
        """Set identity normalisation (mean 0.0, std 1.0) and 10 output classes."""
        self.dataset = self.Dataset()
        self.dataset.mean = 0.0
        self.dataset.std = 1.0    
        self.nclass = 10

    def classifier(self, x, scales=3, filters=32, repeat=4, training=False, getter=None, dropout=0, **kwargs):
        """Build the classifier graph on top of `x` and return its outputs.

        All variables are created under the `classify` variable scope with
        `reuse=tf.AUTO_REUSE`.

        Args:
            x: Input tensor. Channels are read from axis 3 and the spatial mean
                is taken over axes 1 and 2, so a 4-D channels-last tensor is
                expected.
            scales: Number of stages. Stage `s` uses `filters << s` filters;
                every stage except the first starts with a stride-2 block.
            filters: Number of filters of the first stage.
            repeat: Residual blocks per stage (one stage-entry block plus
                `repeat - 1` plain blocks).
            training: Forwarded as `training` to every batch-normalisation
                layer; dropout is only applied when it is truthy.
            getter: Passed as `custom_getter` to the variable scope.
            dropout: Drop probability; `1 - dropout` is handed to
                `tf.nn.dropout` as its second positional argument. `0` disables it.
            **kwargs: Accepted and discarded.

        Returns:
            Tuple `(logits, embeds, out)`:
            `logits` is the dense-layer output with `self.nclass` units,
            `embeds` is the batch-normalised, activated feature map before
            spatial averaging, and `out` is a dict of intermediate activations
            keyed `"c1"`, `"res_{scale}"` and `"res_{scale}_{i}"`.
        """
        del kwargs
        leaky_relu = functools.partial(tf.nn.leaky_relu, alpha=0.1)
        bn_args = dict(training=training, momentum=0.999)

        def conv_args(k, f):
            """Return conv kwargs: 'same' padding and a normal init with stddev rsqrt(0.5 * k * k * f)."""
            return dict(padding='same',
                        kernel_initializer=tf.random_normal_initializer(stddev=tf.rsqrt(0.5 * k * k * f)))
            # Disabled alternative (presumably a deterministic init for debugging -- confirm before re-enabling):
            # kernel_initializer=tf.constant_initializer(0.00123))

        def residual(x0, filters, stride=1, activate_before_residual=False):
            """Residual block: BN -> leaky ReLU -> 3x3 conv, twice, added to a skip branch.

            Only the first convolution (and the optional 1x1 projection) uses `stride`.
            """
            x = leaky_relu(tf.layers.batch_normalization(x0, **bn_args))
            if activate_before_residual:
                # The skip branch starts from the activated tensor instead of the raw input.
                x0 = x

            x = tf.layers.conv2d(x, filters, 3, strides=stride, **conv_args(3, filters))
            x = leaky_relu(tf.layers.batch_normalization(x, **bn_args))
            x = tf.layers.conv2d(x, filters, 3, **conv_args(3, filters))

            # A 1x1 projection is inserted only when the channel count (axis 3) differs.
            # NOTE: a strided block whose input already has `filters` channels gets no projection.
            if x0.get_shape()[3] != filters:
                x0 = tf.layers.conv2d(x0, filters, 1, strides=stride, **conv_args(1, filters))

            return x0 + x

        # Intermediate activations, keyed by layer label, returned for inspection.
        out = {}
        
        with tf.variable_scope('classify', reuse=tf.AUTO_REUSE, custom_getter=getter):
            # Stem: normalise the input, then a 16-filter 3x3 convolution.
            out["c1"] = y = tf.layers.conv2d((x - self.dataset.mean) / self.dataset.std, 16, 3, **conv_args(3, 16))
            for scale in range(scales):
                # Stage entry: stride 1 for scale 0, stride 2 afterwards; only scale 0 activates before the skip.
                out["res_{}".format(scale)] = y = residual(y, filters << scale, stride=2 if scale else 1, activate_before_residual=scale == 0)
                for i in range(repeat - 1):
                    out["res_{}_{}".format(scale, i)] = y = residual(y, filters << scale)

            embeds = y = leaky_relu(tf.layers.batch_normalization(y, **bn_args))
            # Average over axes 1 and 2 (the spatial axes of a channels-last tensor).
            y = tf.reduce_mean(y, [1, 2])
            if dropout and training:
                y = tf.nn.dropout(y, 1 - dropout)
            logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.glorot_normal_initializer())
            # Disabled alternative (presumably a deterministic init for debugging -- confirm before re-enabling):
            # logits = tf.layers.dense(y, self.nclass, kernel_initializer=tf.constant_initializer(0.01432))

        return logits, embeds, out


In [6]:
# Fix the graph-level random seed before the model object is created.
tf.set_random_seed(seed=0)
tf_model = ResNet()

In [7]:
# Float32 placeholder named 'x' with an unspecified leading (batch) dimension.
tf_x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3], name='x')

# `classifier` returns (logits, embeds, dict of intermediate activations).
tf_outputs = tf_model.classifier(tf_x)
tf_logits, tf_embeds, tf_out = tf_outputs

Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
Instructions for updating:
Use keras.layers.BatchNormalization instead.  In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.batch_normalization` documentation).












Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Use keras.layers.dense instead.


In [8]:
# Map every global TF variable to its name, keeping collection order.
weights_dict = {
    var.name: var
    for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
}

In [9]:
import numpy as np

# Seeded NumPy RNG so the random input batch is reproducible.
np.random.seed(0)
np_x = np.random.rand(8, 32, 32, 3).astype("float32")

# Tensors to evaluate in one pass: logits, embeddings and intermediate activations.
fetches = [tf_logits, tf_embeds, tf_out]

with tf.Session() as s:
    s.run(tf.global_variables_initializer())
    res = s.run(fetches, feed_dict={tf_x: np_x})
    # Pull the initialised variable values out as NumPy arrays, keyed by variable name.
    np_weights_dict = s.run(weights_dict)

In [10]:
res[0]

array([[-4.826606 , -9.9340925,  3.8899462, -2.6317124,  5.45304  ,
        -5.7924585,  4.752457 ,  1.5139014,  8.014108 , -3.4554937],
       [-4.890322 , -9.340751 ,  3.6937258, -2.844387 ,  5.346814 ,
        -6.0465345,  4.6167617,  1.7033981,  7.9195976, -3.6339433],
       [-5.194474 , -9.558173 ,  4.1587205, -2.3802965,  5.5766435,
        -6.065986 ,  4.8269773,  1.7820188,  8.294075 , -3.2630131],
       [-4.4036365, -9.257714 ,  3.7970524, -2.2597632,  5.687299 ,
        -5.9189053,  4.6535378,  1.601984 ,  7.9605904, -3.756865 ],
       [-4.565281 , -9.504535 ,  3.5524337, -2.6782646,  5.3547196,
        -5.421666 ,  5.0375605,  1.4634992,  8.065021 , -3.617325 ],
       [-4.792373 , -9.602575 ,  3.3932242, -2.8439019,  5.5520005,
        -6.3028493,  4.731669 ,  1.5132834,  8.200872 , -3.8719633],
       [-5.025964 , -9.571217 ,  4.203959 , -2.7074006,  5.699357 ,
        -6.0045753,  4.875746 ,  2.1217115,  8.213163 , -3.7344334],
       [-4.4215446, -9.851782 ,  3.827945

In [11]:
# List every fetched TF variable together with its array shape.
for tf_var_name, tf_var_value in np_weights_dict.items():
    print(tf_var_name, tf_var_value.shape)

classify/conv2d/kernel:0 (3, 3, 3, 16)
classify/conv2d/bias:0 (16,)
classify/batch_normalization/gamma:0 (16,)
classify/batch_normalization/beta:0 (16,)
classify/batch_normalization/moving_mean:0 (16,)
classify/batch_normalization/moving_variance:0 (16,)
classify/conv2d_1/kernel:0 (3, 3, 16, 32)
classify/conv2d_1/bias:0 (32,)
classify/batch_normalization_1/gamma:0 (32,)
classify/batch_normalization_1/beta:0 (32,)
classify/batch_normalization_1/moving_mean:0 (32,)
classify/batch_normalization_1/moving_variance:0 (32,)
classify/conv2d_2/kernel:0 (3, 3, 32, 32)
classify/conv2d_2/bias:0 (32,)
classify/conv2d_3/kernel:0 (1, 1, 16, 32)
classify/conv2d_3/bias:0 (32,)
classify/batch_normalization_2/gamma:0 (32,)
classify/batch_normalization_2/beta:0 (32,)
classify/batch_normalization_2/moving_mean:0 (32,)
classify/batch_normalization_2/moving_variance:0 (32,)
classify/conv2d_4/kernel:0 (3, 3, 32, 32)
classify/conv2d_4/bias:0 (32,)
classify/batch_normalization_3/gamma:0 (32,)
classify/batch_nor

### PyTorch implementation

In [13]:
import sys
sys.path.insert(0, "..")

In [14]:
import torch

from wrn import WideResNet

In [15]:
model = WideResNet(num_classes=10)

In [16]:
# Collect the PyTorch parameters by name (in registration order), then list their shapes.
pt_weights_dict = dict(model.named_parameters())
for n, p in pt_weights_dict.items():
    print(n, p.shape)

conv.weight torch.Size([16, 3, 3, 3])
conv.bias torch.Size([16])
res_blocks.0.0.bn1.weight torch.Size([16])
res_blocks.0.0.bn1.bias torch.Size([16])
res_blocks.0.0.conv1.weight torch.Size([32, 16, 3, 3])
res_blocks.0.0.conv1.bias torch.Size([32])
res_blocks.0.0.bn2.weight torch.Size([32])
res_blocks.0.0.bn2.bias torch.Size([32])
res_blocks.0.0.conv2.weight torch.Size([32, 32, 3, 3])
res_blocks.0.0.conv2.bias torch.Size([32])
res_blocks.0.0.skip.weight torch.Size([32, 16, 1, 1])
res_blocks.0.0.skip.bias torch.Size([32])
res_blocks.0.1.bn1.weight torch.Size([32])
res_blocks.0.1.bn1.bias torch.Size([32])
res_blocks.0.1.conv1.weight torch.Size([32, 32, 3, 3])
res_blocks.0.1.conv1.bias torch.Size([32])
res_blocks.0.1.bn2.weight torch.Size([32])
res_blocks.0.1.bn2.bias torch.Size([32])
res_blocks.0.1.conv2.weight torch.Size([32, 32, 3, 3])
res_blocks.0.1.conv2.bias torch.Size([32])
res_blocks.0.2.bn1.weight torch.Size([32])
res_blocks.0.2.bn1.bias torch.Size([32])
res_blocks.0.2.conv1.weight

### Setup weights

In [17]:
pt_weights_names = list(pt_weights_dict.keys())

# TF variables are paired with PyTorch parameters purely by position.
# The batch-norm `moving_*` statistics are skipped because `pt_weights_dict`
# is built from `named_parameters()` and has no entries to receive them.
tf_param_names = [name for name in np_weights_dict if "moving_" not in name]

# Fail loudly on a partial or overflowing transfer instead of leaving some
# PyTorch parameters at their random initial values.
if len(tf_param_names) != len(pt_weights_names):
    raise ValueError(
        "Parameter count mismatch: {} TF variables vs {} PyTorch parameters".format(
            len(tf_param_names), len(pt_weights_names)
        )
    )

for tf_name, pt_name in zip(tf_param_names, pt_weights_names):
    w = np_weights_dict[tf_name]
    if w.ndim == 4:
        # Conv kernel: TF (kh, kw, in, out) -> PyTorch (out, in, kh, kw).
        w = w.transpose((-1, -2, 0, 1))
    elif w.ndim == 2:
        # Dense kernel: TF (in, out) -> PyTorch (out, in).
        w = w.transpose((1, 0))

    target = pt_weights_dict[pt_name]
    # Assigning `.data` performs no shape check, so a positional misalignment
    # would otherwise silently replace the parameter with a wrong-shaped tensor.
    if tuple(target.shape) != tuple(w.shape):
        raise ValueError(
            "Shape mismatch for {} -> {}: {} vs {}".format(
                tf_name, pt_name, tuple(w.shape), tuple(target.shape)
            )
        )

    target.data = torch.from_numpy(w)

### Check forward pass

In [19]:
# Reorder the NumPy batch from (N, H, W, C) to (N, C, H, W) before wrapping it as a tensor.
x = torch.from_numpy(np.transpose(np_x, (0, 3, 1, 2)))
x.shape

torch.Size([8, 3, 32, 32])

C1: output of the first convolution

In [20]:
# Evaluation mode, no autograd tracking: run only the stem convolution.
model.eval()
with torch.no_grad():
    stem_out = model.conv(x)

y = stem_out
y.shape

torch.Size([8, 16, 32, 32])

In [21]:
np.abs(y.detach().numpy() - res[2]['c1'].transpose(0, 3, 1, 2)).mean()

0.0

C1 + res_0

In [26]:
# Evaluation mode, no autograd tracking: stem convolution followed by the
# first block of the first stage.
model.eval()
first_stage = model.res_blocks[0]
with torch.no_grad():
    y = first_stage[0](model.conv(x))

y.shape

torch.Size([8, 32, 32, 32])

In [27]:
np.abs(y.detach().numpy() - res[2]['res_0'].transpose(0, 3, 1, 2)).mean()

0.0002001329

C1 + res_0 + res_0_0

In [30]:
# Evaluation mode, no autograd tracking: stem convolution followed by the
# first two blocks of the first stage, applied in order.
model.eval()
first_stage = model.res_blocks[0]
with torch.no_grad():
    y = model.conv(x)
    for block_idx in (0, 1):
        y = first_stage[block_idx](y)

y.shape

torch.Size([8, 32, 32, 32])

In [32]:
np.abs(y.detach().numpy() - res[2]['res_0_0'].transpose(0, 3, 1, 2)).mean()

0.000370559

C1 + res_0 + res_0_0 + res_0_1

In [37]:
# Evaluation mode, no autograd tracking: stem convolution followed by the
# first three blocks of the first stage, applied in order.
model.eval()
first_stage = model.res_blocks[0]
with torch.no_grad():
    y = model.conv(x)
    for block_idx in range(3):
        y = first_stage[block_idx](y)

y.shape

torch.Size([8, 32, 32, 32])

In [38]:
np.abs(y.detach().numpy() - res[2]['res_0_1'].transpose(0, 3, 1, 2)).mean()

0.0008550567

In [29]:
# TF embeddings (second fetched output) reordered from (N, H, W, C) to (N, C, H, W);
# show the first channel of the first sample.
tt = np.transpose(res[1], (0, 3, 1, 2))
tt[0, 0]

array([[-0.04882129, -0.6541816 , -0.68975157, -0.7064169 , -0.12633513,
        -0.1601686 , -0.09963267, -0.15335017],
       [-0.40544873, -0.29413065, -0.5399385 , -0.15704231, -0.5762207 ,
        -0.29151377,  4.4074793 ,  1.6000389 ],
       [-0.3569432 ,  0.8366267 , -0.08593757, -0.3662109 , -0.55905956,
        -0.03731953,  0.5960971 , -0.20261107],
       [-0.1551569 ,  0.55575985,  0.27649838, -0.5633737 , -0.6681551 ,
        -0.10593164,  0.49091116, -0.7166713 ],
       [-0.17578112, -0.87082785, -0.44707137, -0.62656134, -0.01327674,
        -0.60499626, -0.28430307, -0.4966391 ],
       [-0.12249473, -0.28263515, -0.4631001 , -0.11124589, -0.16990219,
        -0.41853818, -0.77722937, -0.27153394],
       [ 0.5397306 , -0.05010793,  2.3187895 , -0.1615778 , -0.18258236,
        -0.37532791, -0.19256835, -0.91219485],
       [ 1.3061581 ,  4.1853523 ,  2.4442306 ,  5.1892548 ,  4.3667893 ,
         0.07944617,  4.0131106 , -0.0449759 ]], dtype=float32)

In [30]:
embeds[0, 0]

tensor([[ 0.0329, -0.8374, -0.6899, -0.7614, -0.1965, -0.2086, -0.0704, -0.2827],
        [-0.3057, -0.8483, -0.5455, -0.7675, -1.1223, -0.6844,  2.7685,  0.2753],
        [-0.1606, -0.2199, -0.1464, -0.4656, -0.4709, -0.0835,  3.8320, -0.0853],
        [-0.2396, -0.3816,  2.2154, -0.3218, -0.1382, -0.2471, -0.2616, -0.3664],
        [-0.3701, -0.7153, -0.0528, -0.4299, -0.4102, -0.3094, -0.3719, -0.3120],
        [-0.2454, -0.3176, -0.5936,  0.0936, -0.2541, -0.3764, -0.0661, -0.6242],
        [ 0.2177, -0.4032,  0.1267, -0.4160, -0.6692, -0.3563, -0.2826, -0.8369],
        [ 2.1535,  3.2108,  5.1179,  6.1287,  2.0405,  1.2819,  0.0713, -0.2172]])

In [25]:
# Evaluation mode, no autograd tracking: full forward pass of the PyTorch model.
model.eval()
with torch.no_grad():
    logits = model(x)

y = logits
y.shape

torch.Size([8, 10])

In [26]:
y[0, :]

tensor([-5.1307, -9.5296,  4.1101, -2.7820,  6.1873, -5.7317,  4.8811,  1.3497,
         8.0847, -3.2070])

In [27]:
res[0][0, :]

array([-4.826606 , -9.9340925,  3.8899462, -2.6317124,  5.45304  ,
       -5.7924585,  4.752457 ,  1.5139014,  8.014108 , -3.4554937],
      dtype=float32)

In [28]:
np.abs(y.detach().numpy() - res[0]).mean()

0.2863981