In [1]:
import argparse
import os
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from os.path import basename
from os.path import splitext
from torchvision import transforms
from torchvision.utils import save_image

%load_ext autoreload
%autoreload 2
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as T

import tensorflow.keras as keras
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

from operator import itemgetter
def print_public_attributes(obj):
    print(list([att for att in dir(obj) 
         if '__' not in att and not att.startswith('_')]))

In [2]:
def calc_mean_std_pt(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2] # N,C,H,W
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def calc_mean_std_tf(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
#     feat_mean, feat_var = tf.nn.moments(feat,[2,3],keepdims=True,)
    shape = tf.shape(feat)
    bessel_correction = shape[2] * shape[3] / (shape[2] * shape[3] - 1)
    feat_flat = tf.reshape(feat,[shape[0],shape[1],shape[2] * shape[3]])
    feat_mean = tf.reduce_mean(feat_flat,[2],keepdims=True)
    feat_var = tf.math.reduce_variance(feat_flat,[2],keepdims=True) * bessel_correction + eps
    feat_std = tf.sqrt(feat_var)
#     feat_ += eps
#     feat_std = tf.sqrt(feat_var)
    return feat_mean, feat_std

with torch.no_grad():
    inval = np.arange(4).reshape(1,1,2,2) / 1.
    exp = calc_mean_std_pt(
        torch.tensor(inval)
    )
    act = calc_mean_std_tf(
        tf.constant(inval)
    )
    for a,b in zip(exp,act):
        assert np.allclose(a.numpy(),b.numpy())

2022-02-28 08:57:01.823225: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [3]:
tf.math.reduce_variance(inval), torch.var(torch.tensor(inval),unbiased=False)

(<tf.Tensor: shape=(), dtype=float64, numpy=1.25>,
 tensor(1.2500, dtype=torch.float64))

In [4]:
exp, act

((tensor([[[[1.5000]]]], dtype=torch.float64),
  tensor([[[[1.2910]]]], dtype=torch.float64)),
 (<tf.Tensor: shape=(1, 1, 1), dtype=float64, numpy=array([[[1.5]]])>,
  <tf.Tensor: shape=(1, 1, 1), dtype=float64, numpy=array([[[1.29099832]]])>))

In [5]:
def mean_variance_norm_pt(feat):
    size = feat.size()
    mean, std = calc_mean_std_pt(feat)
    normalized_feat = (feat - mean.expand(size)) / std.expand(size)
    return normalized_feat

def mean_variance_norm_tf(feat):
    shape = tf.shape(feat)
    mean, std = calc_mean_std_tf(feat)
    normalized_feat = (feat - tf.broadcast_to(mean,shape)) \
        / tf.broadcast_to(std,shape)
    return normalized_feat


with torch.no_grad():
    inval = np.arange(4).reshape(1,1,2,2) / 1.
    exp = mean_variance_norm_pt(
        torch.tensor(inval)
    )
    act = mean_variance_norm_tf(
        tf.constant(inval)
    )
    assert np.allclose(exp.numpy(),act.numpy())

In [6]:
calc_mean_std_tf(inval)

(<tf.Tensor: shape=(1, 1, 1), dtype=float64, numpy=array([[[1.5]]])>,
 <tf.Tensor: shape=(1, 1, 1), dtype=float64, numpy=array([[[1.29099832]]])>)

In [9]:
from tensorflow.keras.applications import efficientnet

In [21]:
enet = efficientnet.EfficientNetB0(False,input_shape=(512,512,3))
enet.summary()

Model: "efficientnetb0"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
rescaling_2 (Rescaling)         (None, 512, 512, 3)  0           input_4[0][0]                    
__________________________________________________________________________________________________
normalization_2 (Normalization) (None, 512, 512, 3)  7           rescaling_2[0][0]                
__________________________________________________________________________________________________
stem_conv_pad (ZeroPadding2D)   (None, 513, 513, 3)  0           normalization_2[0][0]            
_____________________________________________________________________________________

In [22]:
vggnet = tf.keras.applications.vgg19.VGG19(False, input_shape=(512,512,3))
vggnet.summary()

Model: "vgg19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 512, 512, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 512, 512, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 512, 512, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 256, 256, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 256, 256, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 256, 256, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 128, 128, 128)     0     

In [19]:
%%timeit
vggnet(tf.random.normal((1,512,512,3))).shape

1.76 s ± 27.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
%%timeit
enet(tf.random.normal((1,512,512,3))).shape

312 ms ± 13.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [58]:
dec_arch = [
    ['refpad',[1,1,1,1]],
    ['conv',[512,256,(3,3)]],
    ['relu'],
    ['upsample',2],
    
    ['refpad',[1,1,1,1]],
    ['conv',[256,256,(3,3)]],
    ['relu'],
    ['refpad',[1,1,1,1]],
    ['conv',[256,256,(3,3)]],
    ['relu'],
    ['refpad',[1,1,1,1]],
    ['conv',[256,256,(3,3)]],
    ['relu'],
    ['refpad',[1,1,1,1]],
    ['conv',[256,128,(3,3)]],
    ['relu'],    
    ['upsample',2],
    
    
    ['refpad',[1,1,1,1]],
    ['conv',[128,128,(3,3)]],
    ['relu'],    
    ['refpad',[1,1,1,1]],
    ['conv',[128,64,(3,3)]],
    ['relu'],    
    ['upsample',2],
    
    ['refpad',[1,1,1,1]],
    ['conv',[64,64,(3,3)]],
    ['relu'],    
    ['refpad',[1,1,1,1]],
    ['conv',[64,3,(3,3)]],
]

def make_encoder():
    enet = efficientnet.EfficientNetB0(False,input_shape=(None, None,3))
#     inputs = tf.keras.Input(shape=(None,None,3))
#     output = enet(inputs)
    output_1 = enet.get_layer('block5a_activation').output
    output_2 = enet.get_layer('block6a_activation').output
    encoder = tf.keras.Model(inputs=enet.inputs,
                             outputs={'enc_a':output_1,'enc_b':output_2})
    encoder.trainable = False
    assert encoder(tf.random.normal((2,512,512,3))) is not None
    return encoder

def make_decoder():
    inputs = tf.keras.Input(shape=(None,None,512))
    
    outputs = inputs
    for layer_info in dec_arch:
        if layer_info[0] == 'relu':
            outputs = tf.keras.layers.ReLU()(outputs)
        elif layer_info[0] == 'conv':
            _,[inc,outc,[kw,kh]] = layer_info
            outputs = tf.keras.layers.Conv2D(outc,kernel_size=(kw,kh))(outputs)
        elif layer_info[0] == 'refpad':
            _,pad_size = layer_info
            # outputs = [NB, W, H, C] -> Pads = [4, 2]
            outputs = tf.pad(outputs,  paddings=[[0,0],[1,1],[1,1],[0,0]], mode='REFLECT', )
        elif layer_info[0] == 'upsample':
            _,factor = layer_info
            outputs = tf.keras.layers.UpSampling2D(size=(factor,factor))(outputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    
    assert model(tf.random.normal((1,16,16,512))) is not None
    return model
encoder = make_encoder()
decoder = make_decoder()

In [59]:
%%timeit
encoder(tf.random.normal((1,512,512,3))).items()

246 ms ± 2.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [61]:
%%timeit
decoder(tf.random.normal((2,16,16,512))).shape

191 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
