Skip to content

Commit

Permalink
[TEST] TensorFlow 10, 11 Python 2, 3
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Oct 29, 2016
1 parent 037d2a7 commit 3a44af1
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ If you already had the pre-requisites ready, the simplest way to install TensorL


```bash
[for stable version] pip install tensorlayer==1.2.2
[for stable version] pip install tensorlayer==1.2.3
[for master version] pip install git+https://github.com/zsdonghao/tensorlayer.git
```

Expand Down
6 changes: 6 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

from __future__ import absolute_import


from . import imagenet_classes
# from . import
15 changes: 11 additions & 4 deletions tensorlayer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2671,15 +2671,22 @@ def __init__(
layer = None,
slim_layer = None,
slim_args = {},
name ='slim_layer',
name ='InceptionV3',
):
Layer.__init__(self, name=name)
self.inputs = layer.outputs
print(" tensorlayer:Instantiate SlimNetsLayer %s: %s" % (self.name, slim_layer.__name__))

with tf.variable_scope(name) as vs:
net, end_points = slim_layer(self.inputs, **slim_args)
slim_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
# with tf.variable_scope(name) as vs:
# net, end_points = slim_layer(self.inputs, **slim_args)
# slim_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)

net, end_points = slim_layer(self.inputs, **slim_args)

slim_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=name)
if slim_variables == []:
print("No variables found under %s : the name of SlimNetsLayer should be matched with the begining of the ckpt file, see tutorial_inceptionV3_tfslim.py for more details" % name)


self.outputs = net

Expand Down
6 changes: 3 additions & 3 deletions tensorlayer/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def W(W=None, second=10, saveable=True, shape=[28,28], name='mnist', fig_idx=239
# feature = np.zeros_like(feature)
plt.imshow(np.reshape(feature ,(shape[0],shape[1])),
cmap='gray', interpolation="nearest")#, vmin=np.min(feature), vmax=np.max(feature))
plt.title(name)
# plt.title(name)
# ------------------------------------------------------------
# plt.imshow(np.reshape(W[:,count-1] ,(np.sqrt(size),np.sqrt(size))), cmap='gray', interpolation="nearest")
plt.gca().xaxis.set_major_locator(plt.NullLocator()) # distable tick
Expand Down Expand Up @@ -223,11 +223,11 @@ def images2d(images=None, second=10, saveable=True, name='images', dtype=None,
plt.imshow(
np.reshape(images[count-1,:,:], (n_row, n_col)),
cmap='gray', interpolation="nearest")
plt.title(name)
# plt.title(name)
elif n_color == 3:
plt.imshow(images[count-1,:,:],
cmap='gray', interpolation="nearest")
plt.title(name)
# plt.title(name)
else:
raise Exception("Unknown n_color")
plt.gca().xaxis.set_major_locator(plt.NullLocator()) # distable tick
Expand Down
11 changes: 8 additions & 3 deletions tutorial_generate_text.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# Copyright 2016 TensorLayer. All Rights Reserved.
#! /usr/bin/python
# -*- coding: utf8 -*-



# Copyright 2016 TensorLayer. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -264,7 +269,7 @@ def inference(x, is_training, num_steps, reuse=None):
network = tl.layers.DropoutLayer(network, keep=keep_prob, name='drop1')
network = tl.layers.RNNLayer(network,
cell_fn=tf.nn.rnn_cell.BasicLSTMCell,
cell_init_args={'forget_bias': 0.0},# 'state_is_tuple': True},
cell_init_args={'forget_bias': 0.0, 'state_is_tuple': True},
n_hidden=hidden_size,
initializer=tf.random_uniform_initializer(-init_scale, init_scale),
n_steps=num_steps,
Expand All @@ -275,7 +280,7 @@ def inference(x, is_training, num_steps, reuse=None):
network = tl.layers.DropoutLayer(network, keep=keep_prob, name='drop2')
network = tl.layers.RNNLayer(network,
cell_fn=tf.nn.rnn_cell.BasicLSTMCell,
cell_init_args={'forget_bias': 0.0}, # 'state_is_tuple': True},
cell_init_args={'forget_bias': 0.0, 'state_is_tuple': True},
n_hidden=hidden_size,
initializer=tf.random_uniform_initializer(-init_scale, init_scale),
n_steps=num_steps,
Expand Down
2 changes: 1 addition & 1 deletion tutorial_inceptionV3_tfslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def print_prob(prob):
# 'reuse' : None,
# 'scope' : 'InceptionV3'
},
name=''
name='InceptionV3' # <-- the name should be the same with the ckpt model
)
saver = tf.train.Saver()

Expand Down
4 changes: 2 additions & 2 deletions tutorial_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def main_test_layers(model='relu'):

params = network.all_params
# train
n_epoch = 1
n_epoch = 100
batch_size = 128
learning_rate = 0.0001
print_freq = 10
print_freq = 5
train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999,
epsilon=1e-08, use_locking=False).minimize(cost)

Expand Down
5 changes: 3 additions & 2 deletions tutorial_tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import tensorflow as tf
import tensorlayer as tl
import os
import numpy as np
from PIL import Image
import io
import os


"""
Expand Down Expand Up @@ -67,7 +68,7 @@
label = example.features.feature['label'].int64_list.value
## converts a image from bytes
image = Image.frombytes('RGB', (224, 224), img_raw[0])
tl.visualize.frame(image, second=0.5, saveable=False, name='frame', fig_idx=1283)
tl.visualize.frame(np.asarray(image), second=0.5, saveable=False, name='frame', fig_idx=1283)
print(label)


Expand Down
2 changes: 1 addition & 1 deletion tutorial_tfrecord2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
## Visualize a image
# tl.visualize.frame(np.asarray(img, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236)
label = int(y_train[index])
print(label)
# print(label)
## Convert the bytes back to image as follow:
# image = Image.frombytes('RGB', (32, 32), img_raw)
# image = np.fromstring(img_raw, np.float32)
Expand Down
12 changes: 9 additions & 3 deletions tutorial_tfrecord3.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _bytes_feature_list(values):
c = tf.contrib.learn.run_n(features, n=1, feed_dict=None)
from PIL import Image
im = Image.frombytes('RGB', (299, 299), c[0]['image/img_raw'])
tl.visualize.frame(im, second=1, saveable=False, name='frame', fig_idx=1236)
tl.visualize.frame(np.asarray(im), second=1, saveable=False, name='frame', fig_idx=1236)
c = tf.contrib.learn.run_n(sequence_features, n=1, feed_dict=None)
print(c[0])

Expand Down Expand Up @@ -334,10 +334,16 @@ def prefetch_input_data(reader,
img = tf.decode_raw(context["image/img_raw"], tf.uint8)
img = tf.reshape(img, [height, width, 3])
img = tf.image.convert_image_dtype(img, dtype=tf.float32)
# for TensorFlow 0.10
# img = tf.image.resize_images(img,
# new_height=resize_height,
# new_width=resize_width,
# method=tf.image.ResizeMethod.BILINEAR)
# for TensorFlow 0.11
img = tf.image.resize_images(img,
new_height=resize_height,
new_width=resize_width,
size=(resize_height, resize_width),
method=tf.image.ResizeMethod.BILINEAR)

# Crop to final dimensions.
if is_training:
img = tf.random_crop(img, [height, width, 3])
Expand Down

0 comments on commit 3a44af1

Please sign in to comment.