In [1]:
%reload_ext watermark
%reload_ext autoreload
%autoreload 2
%watermark -p numpy,sklearn,pandas
%watermark -p ipywidgets,cv2,PIL,matplotlib,plotly
%watermark -p torch,torchvision,torchaudio
%watermark -p tensorflow,tensorboard,tflite
%watermark -p onnx,onnxruntime,tensorrt,tvm
%matplotlib inline
%config InlineBackend.figure_format='retina'
%config IPCompleter.use_jedi = False

from IPython.display import display, Markdown, HTML, Image, Javascript
from IPython.core.magic import register_line_cell_magic, register_line_magic, register_cell_magic
display(HTML('<style>.container { width:%d%% !important; }</style>' % 95))

import sys, os, io, time, random, math
import json, base64, requests, shutil
import os.path as osp
import numpy as np

def _remote_source_candidates(spec):
    """Return the raw-file URLs to try, in order, for a host-qualified ``spec`` (no scheme)."""
    spec = spec.replace('blob/main/', '').replace('blob/master/', '')
    parts = spec.split('/')
    host, repo_root, path = parts[0], '/'.join(parts[1:3]), '/'.join(parts[3:])
    if host == 'raw.githubusercontent.com':
        return ['https://' + spec]
    if host == 'github.com':
        # github.com/<user>/<repo>/<path> -> raw.githubusercontent.com/<user>/<repo>/<branch>/<path>
        return [f'https://raw.githubusercontent.com/{repo_root}/{branch}/{path}'
                for branch in ('main', 'master')]
    if host == 'gitee.com':
        # gitee.com/<user>/<repo>/<path> -> gitee.com/<user>/<repo>/raw/<branch>/<path>
        return [f'https://gitee.com/{repo_root}/raw/{branch}/{path}'
                for branch in ('main', 'master')]
    return []


def _fetch_remote_source(spec, timeout):
    """Download the first reachable candidate for ``spec``; return its text, or None on failure."""
    candidates = _remote_source_candidates(spec)
    if not candidates:
        print(f'_IMPORT: unsupported location {spec!r}', file=sys.stderr)
        return None
    for uri in candidates:
        resp = requests.get(uri, timeout=timeout)
        if resp.status_code == 200:
            return resp.text
    print(f'_IMPORT: could not fetch {spec!r} (tried {candidates})', file=sys.stderr)
    return None


def _IMPORT(x, timeout=10):
    """Fetch Python source and ``exec`` it into this notebook's global namespace.

    ``x`` may be any of the following (a leading ``https://``/``http://`` and the
    ``.py`` suffix are optional):

    * an absolute local path, e.g. ``/path/to/module``;
    * ``raw.githubusercontent.com/<user>/<repo>/<branch>/<path>``;
    * ``github.com/<user>/<repo>/[blob/main/|blob/master/]<path>`` -- the raw file
      is tried on the ``main`` branch, then ``master``;
    * ``gitee.com/<user>/<repo>/<path>`` -- tried at ``raw/main``, then ``raw/master``.

    Failures (unsupported host, missing file, HTTP errors, exceptions raised by the
    executed code) are reported on stderr instead of raised, so the calling cell
    keeps running as before; inspect the return value when it matters.

    Args:
        x: Location of the module, as described above.
        timeout: Per-request timeout in seconds for remote fetches.

    Returns:
        bool: True if the source was obtained and executed without error.
    """
    try:
        spec = x.strip()
        for scheme in ('https://', 'http://'):
            if spec.startswith(scheme):
                spec = spec[len(scheme):]
                break
        if not spec.endswith('.py'):
            spec += '.py'

        if spec.startswith('/'):
            with open(spec) as fr:
                code = fr.read()
        else:
            code = _fetch_remote_source(spec, timeout)
            if code is None:
                return False
        exec(code, globals())
        return True
    except Exception as err:  # best-effort helper: report instead of crashing the cell
        print(f'_IMPORT({x!r}) failed: {err!r}', file=sys.stderr)
        return False

def _DIR(x, dumps=True, ret=True):
    """Summarize the public attributes of ``x``.

    The summary is ``'<type>: <names>'``. ``<type>`` is the quoted part of
    ``str(type(x))``. ``<names>`` is the sorted list of names in ``dir(x)`` that
    do not start with ``_``. It is rendered as a JSON array when ``dumps`` is true
    and as a Python list repr otherwise.

    Returns the summary string when ``ret`` is true; otherwise prints it and
    returns None.
    """
    public_names = sorted(name for name in dir(x) if not name.startswith('_'))
    # str(type(x)) looks like "<class 'pkg.Name'>"; keep only the quoted part.
    type_name = str(type(x))[len("<class '"):-len("'>")]
    listing = json.dumps(public_names) if dumps else public_names
    summary = f'{type_name}: {listing}'
    if not ret:
        print(summary)
        return None
    return summary



numpy 1.19.5
sklearn 0.0
pandas 1.1.5
ipywidgets 7.6.3
cv2 4.5.3
PIL 8.3.1
matplotlib 3.3.4
plotly 5.3.0
torch 1.8.1+cu101
torchvision 0.9.1+cu101
torchaudio not installed
tensorflow 2.6.0
tensorboard 2.6.0
tflite 2.4.0
onnx 1.10.1
onnxruntime 1.8.1
tensorrt not installed
tvm not installed


In [85]:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import numpy as np

In [60]:
# Exec the ResnetPeriodEstimator source into this namespace. Later cells use
# ResnetPeriodEstimator, TransformerLayer, get_sims and flatten_sequential_feats
# without importing them, so they presumably come from here - if this import
# fails, those cells fail with NameError.
_IMPORT('gitee.com/qrsforever/blog_source_codes/AI/tensorflow/misc/ResnetPeriodEstimator')

In [75]:
# Input geometry: test inputs are (BATCH_SIZE, NUM_FRAMES, IMAGE_H, IMAGE_W, 3).
BATCH_SIZE = 1
NUM_FRAMES = 64
IMAGE_H = IMAGE_W = 112
# Temperature argument passed to get_sims() in ResnetPart3.
SIM_TEMPERATURE = 13.544

# Output directory for the exported SavedModels of the split parts.
SAVED_MODEL_ROOT = '/data/nb_data/split_repnet_model'

## Download and Restore Model Weights

In [5]:
# Download the pretrained RepNet checkpoint (ckpt-88) into PATH_TO_CKPT.
# Each file is checked on its own, so an interrupted earlier run is resumed
# instead of being skipped just because the directory already exists.
PATH_TO_CKPT = '/data/pretrained/cv/repnet/'
CKPT_URL_ROOT = 'https://storage.googleapis.com/repnet_ckpt'
CKPT_FILES = [
    'checkpoint',
    'ckpt-88.data-00000-of-00002',
    'ckpt-88.data-00001-of-00002',
    'ckpt-88.index',
]

os.makedirs(PATH_TO_CKPT, exist_ok=True)
for ckpt_file in CKPT_FILES:
    ckpt_dst = osp.join(PATH_TO_CKPT, ckpt_file)
    if osp.exists(ckpt_dst):
        continue
    # Write to a temporary name and rename only after a complete download, so a
    # truncated file is never mistaken for a finished one on the next run.
    ckpt_tmp = ckpt_dst + '.part'
    with requests.get(f'{CKPT_URL_ROOT}/{ckpt_file}', stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(ckpt_tmp, 'wb') as fw:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fw.write(chunk)
    os.replace(ckpt_tmp, ckpt_dst)

In [6]:
# Rebuild the full RepNet model and restore the pretrained weights.
# expect_partial() silences warnings about checkpoint values this model does not use.
model = ResnetPeriodEstimator()
model.call = tf.function(model.call)
tf.train.Checkpoint(model=model).restore(osp.join(PATH_TO_CKPT, 'ckpt-88')).expect_partial()

# Seeded so the reference outputs (and every comparison of the split parts
# against them below) are reproducible after a kernel restart.
TEST_INPUT_SEED = 0
test_inputs = np.random.default_rng(TEST_INPUT_SEED).standard_normal(
    (BATCH_SIZE, NUM_FRAMES, IMAGE_H, IMAGE_W, 3), dtype=np.float32)
test_inputs_tensor = tf.convert_to_tensor(test_inputs)

In [7]:
# Reference outputs of the full model; the split parts are compared against these.
test_outputs = model(test_inputs_tensor)
tuple(output.shape for output in test_outputs[:3])

(TensorShape([1, 64, 32]), TensorShape([1, 64, 1]), TensorShape([1, 64, 512]))

In [8]:
# Reference parameter counts: the split models below reuse these layers, e.g.
# conv3d (14,156,288) == ResnetPart2's temporal_conv_layer.
model.summary()

Model: "resnet_period_estimator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
model (Functional)           (None, None, None, 1024)  5209600   
_________________________________________________________________
conv3d (Conv3D)              multiple                  14156288  
_________________________________________________________________
batch_normalization (BatchNo multiple                  2048      
_________________________________________________________________
conv2d (Conv2D)              multiple                  320       
_________________________________________________________________
dense (Dense)                multiple                  1049088   
_________________________________________________________________
dense_1 (Dense)              multiple                  1049088   
_________________________________________________________________
transformer_layer (Transform multiple      

In [9]:
# Public attributes of the full model; the split parts copy weights from several
# of them (conv_3x3_layer, fc_layers, input_projection, ...).
_DIR(model, dumps=True, ret=True)

'__main__.ResnetPeriodEstimator: ["activity_regularizer", "add_loss", "add_metric", "add_update", "add_variable", "add_weight", "apply", "base_model", "base_model_layer_name", "build", "built", "call", "compile", "compiled_loss", "compiled_metrics", "compute_dtype", "compute_mask", "compute_output_shape", "compute_output_signature", "conv_3x3_layer", "conv_channels", "conv_kernel_size", "count_params", "distribute_strategy", "dropout_layer", "dropout_rate", "dtype", "dtype_policy", "dynamic", "evaluate", "evaluate_generator", "fc_layers", "finalize_state", "fit", "fit_generator", "from_config", "get_config", "get_input_at", "get_input_mask_at", "get_input_shape_at", "get_layer", "get_losses_for", "get_output_at", "get_output_mask_at", "get_output_shape_at", "get_updates_for", "get_weights", "history", "image_size", "inbound_nodes", "input", "input_mask", "input_names", "input_projection", "input_projection2", "input_shape", "input_spec", "inputs", "l2_reg_weight", "layers", "load_weigh

## Split Model

### Model Part-1

In [37]:
class ResnetPart1(tf.keras.models.Model):
    """Stage 1 of the split RepNet: the per-frame ResNet backbone.

    Wraps ``model.base_model`` truncated at its ``conv4_block3_out`` layer. ``call``
    folds the frame axis into the batch axis before running the backbone, so the
    output's leading dimension is batch * frames (``(64, 7, 7, 1024)`` here).
    """
    def __init__(self, model):
        super().__init__(name='ResnetPart1')
        backbone = model.base_model
        cut_layer = backbone.get_layer('conv4_block3_out')
        self.base_model = tf.keras.models.Model(
            inputs=backbone.input, outputs=cut_layer.output, name='base_model')

    def call(self, x):
        """Reshape ``x`` to a stack of (IMAGE_H, IMAGE_W, 3) frames and run the backbone."""
        frames = tf.reshape(x, (-1, IMAGE_H, IMAGE_W, 3))
        return self.base_model(frames)

In [86]:
model1 = ResnetPart1(model)
model_part1_outputs = model1(test_inputs_tensor)
# Part 1 folds frames into the batch axis, so its leading dim is BATCH_SIZE * NUM_FRAMES.
assert model_part1_outputs.shape[0] == BATCH_SIZE * NUM_FRAMES, model_part1_outputs.shape
tf.reshape(model_part1_outputs, (-1,))[:3], '-'*90, model_part1_outputs.shape

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 0.76912963, -0.29019433, -0.14487836], dtype=float32)>,
 '------------------------------------------------------------------------------------------',
 TensorShape([64, 7, 7, 1024]))

In [39]:
# Part 1 should carry exactly the truncated backbone: the same 5,209,600
# parameters as the full model's `model (Functional)` entry.
model1.summary(line_length=90)

Model: "ResnetPart1"
__________________________________________________________________________________________
Layer (type)                            Output Shape                        Param #       
base_model (Functional)                 (None, None, None, 1024)            5209600       
Total params: 5,209,600
Trainable params: 5,193,856
Non-trainable params: 15,744
__________________________________________________________________________________________


In [None]:
# NOTE(review): this cell has no execution count (In [None]) in the saved
# notebook, so model1 may never have been exported; only model2 is reloaded below.
model1.save(filepath=f'{SAVED_MODEL_ROOT}/model1', save_format='tf')

### Model Part-2

In [57]:
class ResnetPart2(tf.keras.models.Model):
    """Stage 2 of the split RepNet: temporal 3D convolution over per-frame features.

    Weights are copied from ``model.temporal_conv_layers[0]`` and
    ``model.temporal_bn_layers[0]`` of the full ``ResnetPeriodEstimator``.

    ``call`` takes the stage-1 output with frames folded into the batch axis
    (``(64, 7, 7, 1024)`` in this notebook). It returns ``(BATCH_SIZE, frames, 512)``
    after a max over the two spatial axes.
    """
    def __init__(self, model):
        super().__init__(name='ResnetPart2')
        # Use tf.keras.layers explicitly: the bare `layers` name is never imported by
        # this notebook (presumably it leaked in from the exec'd module).
        self.temporal_conv_layer = tf.keras.layers.Conv3D(
            512, 3, padding='same', name='temporal_conv_layer',
            dilation_rate=(3, 1, 1), weights=model.temporal_conv_layers[0].get_weights())
        # Attribute/layer name kept as-is so exported variable names do not change.
        self.temporal_bn_layers = tf.keras.layers.BatchNormalization(
            name='temporal_bn_layers', weights=model.temporal_bn_layers[0].get_weights())

    def call(self, x):
        # Unfold frames out of the batch axis: (B*T, h, w, c) -> (B, T, h, w, c).
        # Relies on x having a static shape beyond the first axis.
        x = tf.reshape(x, [BATCH_SIZE, -1] + x.shape.as_list()[1:])
        x = self.temporal_conv_layer(x)
        x = self.temporal_bn_layers(x)
        x = tf.nn.relu(x)
        # Max over the spatial axes (h, w) -> (B, T, channels).
        return tf.reduce_max(x, [2, 3])

In [58]:
model2 = ResnetPart2(model)
model_part2_outputs = model2(model_part1_outputs)
# Part 2 must reproduce the full model's third output (shape (1, 64, 512)).
np.testing.assert_allclose(model_part2_outputs.numpy(), test_outputs[2].numpy(), rtol=1e-4, atol=1e-4)
tf.reshape(model_part2_outputs, (-1,))[:3], '-'*60, tf.reshape(test_outputs[2], (-1,))[:3]

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 1.0784211, 0.7234727], dtype=float32)>,
 '------------------------------------------------------------',
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 1.0784211, 0.7234727], dtype=float32)>)

In [59]:
# Conv3D + BN parameter counts should equal conv3d (14,156,288) and
# batch_normalization (2,048) in the full model's summary.
model2.summary(line_length=90)

Model: "ResnetPart2"
__________________________________________________________________________________________
Layer (type)                            Output Shape                        Param #       
temporal_conv_layer (Conv3D)            multiple                            14156288      
__________________________________________________________________________________________
temporal_bn_layers (BatchNormalization) multiple                            2048          
Total params: 14,158,336
Trainable params: 14,157,312
Non-trainable params: 1,024
__________________________________________________________________________________________


In [77]:
# Export part 2 as a TensorFlow SavedModel directory; it is reloaded further down.
model2.save(f'{SAVED_MODEL_ROOT}/model2', save_format='tf')

INFO:tensorflow:Assets written to: /data/nb_data/split_repnet_model/model2/assets


### Model Part-3

In [69]:
class ResnetPart3(tf.keras.models.Model):
    """Stage 3 of the split RepNet: similarity features -> two transformer heads.

    Every weight is copied from the full ``ResnetPeriodEstimator`` ``model``.
    ``call`` takes the stage-2 features (``(BATCH_SIZE, NUM_FRAMES, 512)`` in this
    notebook). It returns ``(x, within_period_x)``, which correspond to the full
    model's first two outputs, shaped ``(BATCH_SIZE, NUM_FRAMES, NUM_FRAMES // 2)``
    and ``(BATCH_SIZE, NUM_FRAMES, 1)``.
    """
    def __init__(self, model):
        super().__init__(name='ResnetPart3')
        # Explicit tf.keras.layers instead of a bare `layers` name, which this
        # notebook never imports (presumably it leaked in from the exec'd module).
        Dense = tf.keras.layers.Dense

        self.conv_3x3_layer = tf.keras.layers.Conv2D(
            32, 3, padding='same', activation=tf.nn.relu, name='conv_3x3_layer',
            weights=model.conv_3x3_layer.get_weights())

        self.input_projection1 = Dense(
            512, activation=None, name='input_projection1',
            weights=model.input_projection.get_weights())
        self.input_projection2 = Dense(
            512, activation=None, name='input_projection2',
            weights=model.input_projection2.get_weights())

        # tf.Variable is the native TF2 way to hold these tensors; the TF1
        # tf.compat.v1.get_variable shim is unnecessary in eager mode.
        self.pos_encoding1 = tf.Variable(model.pos_encoding.numpy(), name='pos_encoding')
        self.pos_encoding2 = tf.Variable(model.pos_encoding2.numpy(), name='pos_encoding2')

        self.transformer_layer1 = self._clone_transformer(
            model.transformer_layers[0], 'transformer_layer1')
        self.transformer_layer2 = self._clone_transformer(
            model.transformer_layers2[0], 'transformer_layer2')

        self.dropout_layer = tf.keras.layers.Dropout(
            0.25, name='dropout', weights=model.dropout_layer.get_weights())
        self.fc_layers = [
            Dense(512, activation=tf.nn.relu, name='fc_0', weights=model.fc_layers[0].get_weights()),
            Dense(512, activation=tf.nn.relu, name='fc_1', weights=model.fc_layers[1].get_weights()),
            Dense(NUM_FRAMES // 2, name='fc_layers_2', weights=model.fc_layers[2].get_weights()),
        ]
        self.within_period_fc_layers = [
            Dense(512, activation=tf.nn.relu, name='within_period_fc_0', weights=model.within_period_fc_layers[0].get_weights()),
            Dense(512, activation=tf.nn.relu, name='within_period_fc_1', weights=model.within_period_fc_layers[1].get_weights()),
            Dense(1, name='within_period_fc_2', weights=model.within_period_fc_layers[2].get_weights()),
        ]

    @staticmethod
    def _clone_transformer(src, name):
        """Build a TransformerLayer with the hyper-parameters used here and copy ``src``'s weights."""
        layer = TransformerLayer(512, 4, 512, 0.0, True, name=name)
        # One call on a dummy batch creates the variables so set_weights can fill them;
        # zeros keep this deterministic (the values are overwritten anyway).
        layer(tf.zeros((BATCH_SIZE, NUM_FRAMES, 512)))
        layer.set_weights(src.get_weights())
        return layer

    def _run_head(self, x, projection, pos_encoding, transformer, fc_layers):
        """Shared head: projection, + positional encoding, transformer, then (dropout, dense) stack."""
        x = projection(x)
        x += pos_encoding
        x = transformer(x)
        x = flatten_sequential_feats(x, BATCH_SIZE, NUM_FRAMES)
        for fc_layer in fc_layers:
            x = self.dropout_layer(x)
            x = fc_layer(x)
        return x

    def call(self, x):
        x = get_sims(x, SIM_TEMPERATURE)
        x = self.conv_3x3_layer(x)
        x = tf.reshape(x, [BATCH_SIZE, NUM_FRAMES, -1])
        # Both heads read the same flattened features; evaluated in the original order.
        period_x = self._run_head(x, self.input_projection1, self.pos_encoding1,
                                  self.transformer_layer1, self.fc_layers)
        within_period_x = self._run_head(x, self.input_projection2, self.pos_encoding2,
                                         self.transformer_layer2, self.within_period_fc_layers)
        return period_x, within_period_x

In [70]:
model3 = ResnetPart3(model)
model_part3_outputs = model3(model_part2_outputs)
# Both heads must match the full model's first two outputs up to float32 noise
# (differences around 1e-5 are visible in the printed values).
for got, want in zip(model_part3_outputs, test_outputs[:2]):
    np.testing.assert_allclose(got.numpy(), want.numpy(), rtol=1e-3, atol=1e-3)
tf.reshape(model_part3_outputs[0], (-1,))[:3], '-'*60, tf.reshape(test_outputs[0], (-1,))[:3]

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-12.34874  ,   2.2500176,  -1.090362 ], dtype=float32)>,
 '------------------------------------------------------------',
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-12.34874 ,   2.250026,  -1.090359], dtype=float32)>)

In [71]:
# Per-layer counts should match the full model's summary, e.g. conv2d (320) ==
# conv_3x3_layer and dense (1,049,088) == input_projection1.
model3.summary(line_length=90)

Model: "ResnetPart3"
__________________________________________________________________________________________
Layer (type)                            Output Shape                        Param #       
conv_3x3_layer (Conv2D)                 multiple                            320           
__________________________________________________________________________________________
input_projection1 (Dense)               multiple                            1049088       
__________________________________________________________________________________________
input_projection2 (Dense)               multiple                            1049088       
__________________________________________________________________________________________
transformer_layer1 (TransformerLayer)   multiple                            1577984       
__________________________________________________________________________________________
transformer_layer2 (TransformerLayer)   multiple                     

In [79]:
# Export part 3 as a TensorFlow SavedModel directory.
model3.save(f'{SAVED_MODEL_ROOT}/model3', save_format='tf')



INFO:tensorflow:Assets written to: /data/nb_data/split_repnet_model/model3/assets


INFO:tensorflow:Assets written to: /data/nb_data/split_repnet_model/model3/assets


### Test: Chained Parts vs. Full Model

In [72]:
# End-to-end: feed the full model's test input through parts 1 -> 2 -> 3.
outputs = model3(model2(model1(test_inputs_tensor)))
tuple(out.shape for out in outputs)

(TensorShape([1, 64, 32]), TensorShape([1, 64, 1]))

In [73]:
# Chained parts vs. full model, first output; tolerance covers float32 noise (~1e-5 observed).
np.testing.assert_allclose(outputs[0].numpy(), test_outputs[0].numpy(), rtol=1e-3, atol=1e-3)
tf.reshape(outputs[0], (-1,))[:3], '-'*60, tf.reshape(test_outputs[0], (-1,))[:3]

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-12.34874  ,   2.2500176,  -1.090362 ], dtype=float32)>,
 '------------------------------------------------------------',
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-12.34874 ,   2.250026,  -1.090359], dtype=float32)>)

In [74]:
# Chained parts vs. full model, second output; tolerance covers float32 noise (~1e-5 observed).
np.testing.assert_allclose(outputs[1].numpy(), test_outputs[1].numpy(), rtol=1e-3, atol=1e-3)
tf.reshape(outputs[1], (-1,))[:3], '-'*60, tf.reshape(test_outputs[1], (-1,))[:3]

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1.4859768, 1.630618 , 1.6552675], dtype=float32)>,
 '------------------------------------------------------------',
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1.4859803, 1.6306198, 1.6552691], dtype=float32)>)

## Reload the Saved Part-2 Model and Test

In [101]:
# Reload the exported part 2. Both its 'serving_default' signature and the
# restored object's __call__ are compared with model_part2_outputs below.
model2_loaded = tf.saved_model.load(f'{SAVED_MODEL_ROOT}/model2')
model2_graph_infer = model2_loaded.signatures['serving_default']

In [105]:
# The serving signature returns a dict; 'output_1' is the key used for the single output.
model2_signature_outputs = model2_graph_infer(model_part1_outputs)['output_1']
np.testing.assert_allclose(model2_signature_outputs.numpy(), model_part2_outputs.numpy(), rtol=1e-4, atol=1e-4)
tf.reshape(model2_signature_outputs, (-1,))[:3], tf.reshape(model_part2_outputs, (-1,))[:3]

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 1.0784211, 0.7234727], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 1.0784211, 0.7234727], dtype=float32)>)

In [104]:
# The restored object's __call__ must agree with the in-memory part 2.
model2_restored_outputs = model2_loaded(model_part1_outputs)
np.testing.assert_allclose(model2_restored_outputs.numpy(), model_part2_outputs.numpy(), rtol=1e-4, atol=1e-4)
tf.reshape(model2_restored_outputs, (-1,))[:3], tf.reshape(model_part2_outputs, (-1,))[:3]

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 1.0784211, 0.7234727], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 1.0784211, 0.7234727], dtype=float32)>)