In [1]:
import matplotlib as mlp
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras 
import warnings

# Suppress all Python warnings for the rest of the kernel session
# (note: this also hides TensorFlow/Keras deprecation warnings).
warnings.filterwarnings('ignore')
print(tf.__version__)
# The loop variable is named `module`, not `model`: a top-level for-loop leaks
# its variable into the notebook namespace, and in a Keras notebook a stale
# global `model` bound to matplotlib could mask a missing model definition.
for module in (sklearn, pd, keras, np, mlp):
    print(module.__name__, module.__version__)

2.1.0
sklearn 0.20.2
pandas 0.24.2
tensorflow_core.python.keras.api._v2.keras 2.2.4-tf
numpy 1.17.4
matplotlib 2.1.2


In [5]:
# Define an ordinary (eagerly executed) Python function.
def scaled_elu(z, scale=1.0, alpha=1.0):
    """Element-wise scaled ELU.

    Returns ``scale * z`` where ``z >= 0`` and ``scale * alpha * elu(z)``
    where ``z < 0``. With the defaults ``scale=alpha=1.0`` this is plain ELU.

    Args:
        z: input tensor (any shape).
        scale: multiplier applied to both branches.
        alpha: multiplier applied to ``tf.nn.elu(z)`` on the negative branch.
    """
    negative_branch = alpha * tf.nn.elu(z)
    keep_identity = tf.greater_equal(z, 0)
    return scale * tf.where(keep_identity, z, negative_branch)

# Eager execution: scaled_elu runs op by op as ordinary Python.
print(scaled_elu(tf.constant(3.)))
print(scaled_elu(tf.constant(-3.)))

# Wrap the same Python function with tf.function. It is traced into a
# TensorFlow graph on the first call, and re-traced for each new input
# dtype/shape. Graph execution is not automatically faster: for a function
# made of a single element-wise op like this one the gain may be negligible or
# even negative, so measure (see the %timeit cell) rather than assume it.
scaled_elu_tf = tf.function(scaled_elu)
print(scaled_elu_tf(tf.constant(3.)))
print(scaled_elu_tf(tf.constant(-3.)))

tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(-0.95021296, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(-0.95021296, shape=(), dtype=float32)


In [6]:
%timeit scaled_elu(tf.random.normal((1000, 1000)))
%timeit scaled_elu_tf(tf.random.normal((1000, 1000)))

5.44 ms ± 355 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.45 ms ± 29.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# Another way to turn a Python function into a TensorFlow graph: apply
# tf.function as a decorator.
@tf.function
def conver_to_2(n_iter):
    """Sum the first ``n_iter`` terms of 1 + 1/2 + 1/4 + ...

    The partial sum is 2 - 2**(1 - n_iter), so it approaches 2 as ``n_iter``
    grows (1.9999981 in float32 for n_iter=20).

    Args:
        n_iter: number of terms; a Python int or an integer scalar tensor.

    Returns:
        A float32 scalar tensor with the partial sum (0.0 when n_iter <= 0).
    """
    total = tf.constant(0.0)
    increment = tf.constant(1.)
    # Iterating over tf.range (not Python range) lets AutoGraph emit a single
    # tf.while_loop instead of unrolling n_iter copies of the body into the
    # graph, and lets n_iter be a tensor. Each distinct *Python* int still
    # triggers a retrace; pass tf.constant(n) to reuse one traced graph.
    for _ in tf.range(n_iter):
        total += increment
        increment /= 2
    return total

print(conver_to_2(20))
        

tf.Tensor(1.9999981, shape=(), dtype=float32)


In [11]:
# Variables: when converting a Python function with tf.function, create the
# tf.Variable outside the function (here, at notebook level); otherwise
# TensorFlow reports an error.
var = tf.Variable(1.)

@tf.function
def add_10():
    """Add 10 in place to the notebook-level variable ``var``; return the update result."""
    updated = var.assign_add(10)  # graph-friendly equivalent of `var += 10`
    return updated

print(add_10())

tf.Tensor(11.0, shape=(), dtype=float32)


In [12]:
# input_signature restricts which inputs the traced function accepts: here
# only 1-D float32 tensors of any length (the placeholder is named 'x').
@tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name='x')])
def cube(z):
    """Return the element-wise cube of the 1-D float32 tensor ``z``."""
    return tf.math.pow(z, 3)

print(cube(tf.constant([1., 2., 3.])))

tf.Tensor([ 1.  8. 27.], shape=(3,), dtype=float32)
