1 change: 1 addition & 0 deletions README.md
@@ -24,6 +24,7 @@ TensorLayer is a deep learning and reinforcement learning library on top of [Ten
- Useful links: [Documentation](http://tensorlayer.readthedocs.io), [Examples](http://tensorlayer.readthedocs.io/en/latest/user/example.html), [中文文档](https://tensorlayercn.readthedocs.io), [中文书](http://www.broadview.com.cn/book/5059)

# News
* [16 Mar] Release experimental APIs for binary networks.
* [18 Jan] [《深度学习:一起玩转TensorLayer》](http://www.broadview.com.cn/book/5059) (Deep Learning using TensorLayer)
* [17 Dec] Release experimental APIs for distributed training (by [TensorPort](https://tensorport.com)). See [tiny example](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_mnist_distributed.py).
* [17 Nov] Release data augmentation APIs for object detection, see [tl.prepro](http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html#object-detection).
5 changes: 5 additions & 0 deletions docs/modules/files.rst
@@ -9,6 +9,7 @@ API - Files
load_mnist_dataset
load_fashion_mnist_dataset
load_cifar10_dataset
load_cropped_svhn
load_ptb_dataset
load_matt_mahoney_text8_dataset
load_imdb_dataset
@@ -63,6 +64,10 @@ CIFAR-10
^^^^^^^^^^^^
.. autofunction:: load_cifar10_dataset

SVHN
^^^^^^^
.. autofunction:: load_cropped_svhn

Penn TreeBank (PTB)
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: load_ptb_dataset
134 changes: 120 additions & 14 deletions tensorlayer/files.py
@@ -39,6 +39,7 @@
import sys
import tarfile
import zipfile
import time
Member: another import time on #L1130 could be removed now.
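A minimal sketch of that cleanup (the code around #L1130 is not shown in this diff, so the function below is hypothetical): once import time sits at module scope, any function-local import of it is redundant.

import time  # module-scope import added in this PR


def timed_load():
    # a local "import time" here (as at #L1130) can simply be deleted
    start = time.time()
    # ... load something ...
    return time.time() - start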


import numpy as np
import tensorflow as tf
@@ -320,6 +321,106 @@ def unpickle(file):
return X_train, y_train, X_test, y_test


def load_cropped_svhn(path='data', include_extra=True):
"""Load Cropped SVHN.

The Cropped Street View House Numbers (SVHN) Dataset contains 32x32x3 RGB images.
Digit '1' has label 1, '9' has label 9, and '0' has label 0 (the original dataset uses 10 to represent '0'); see the `ufldl website <http://ufldl.stanford.edu/housenumbers/>`__.

Parameters
----------
path : str
The path that the data is downloaded to.
include_extra : boolean
If True (default), add extra images to the training set.

Returns
-------
X_train, y_train, X_test, y_test: tuple
The training images and labels, followed by the test images and labels.

Examples
---------
>>> X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
>>> tl.vis.save_images(X_train[0:100], [10, 10], 'svhn.png')

"""

import scipy.io

start_time = time.time()

path = os.path.join(path, 'cropped_svhn')
logging.info("Load or Download Cropped SVHN > {} | include extra images: {}".format(path, include_extra))
url = "http://ufldl.stanford.edu/housenumbers/"

np_file = os.path.join(path, "train_32x32.npz")
if file_exists(np_file) is False:
filename = "train_32x32.mat"
filepath = maybe_download_and_extract(filename, path, url)
mat = scipy.io.loadmat(filepath)
X_train = mat['X'] / 255.0 # to [0, 1]
X_train = np.transpose(X_train, (3, 0, 1, 2))
y_train = np.squeeze(mat['y'], axis=1)
y_train[y_train == 10] = 0 # replace label 10 with label 0
np.savez(np_file, X=X_train, y=y_train)
del_file(filepath)
else:
v = np.load(np_file)
X_train = v['X']
y_train = v['y']
logging.info(" n_train: {}".format(len(y_train)))

np_file = os.path.join(path, "test_32x32.npz")
if file_exists(np_file) is False:
filename = "test_32x32.mat"
filepath = maybe_download_and_extract(filename, path, url)
mat = scipy.io.loadmat(filepath)
X_test = mat['X'] / 255.0
X_test = np.transpose(X_test, (3, 0, 1, 2))
y_test = np.squeeze(mat['y'], axis=1)
y_test[y_test == 10] = 0
np.savez(np_file, X=X_test, y=y_test)
del_file(filepath)
else:
v = np.load(np_file)
X_test = v['X']
y_test = v['y']
logging.info(" n_test: {}".format(len(y_test)))

if include_extra:
logging.info(" getting extra 531131 images, please wait ...")
np_file = os.path.join(path, "extra_32x32.npz")
if file_exists(np_file) is False:
logging.info(" the first time to load extra images will take long time to convert the file format ...")
filename = "extra_32x32.mat"
filepath = maybe_download_and_extract(filename, path, url)
mat = scipy.io.loadmat(filepath)
X_extra = mat['X'] / 255.0
X_extra = np.transpose(X_extra, (3, 0, 1, 2))
y_extra = np.squeeze(mat['y'], axis=1)
y_extra[y_extra == 10] = 0
np.savez(np_file, X=X_extra, y=y_extra)
del_file(filepath)
else:
v = np.load(np_file)
X_extra = v['X']
y_extra = v['y']
logging.info(" adding n_extra {} to n_train {}".format(len(y_extra), len(y_train)))
t = time.time()
X_train = np.concatenate((X_train, X_extra), 0)
y_train = np.concatenate((y_train, y_extra), 0)
logging.info(" added n_extra {} to n_train {} took {}s".format(len(y_extra), len(y_train), time.time() - t))
else:
logging.info(" no extra images are included")
logging.info(" image size:%s n_train:%d n_test:%d" % (str(X_train.shape[1:4]), len(y_train), len(y_test)))
logging.info(" took: {}s".format(int(time.time() - start_time)))
return X_train, y_train, X_test, y_test
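A hedged usage sketch (not part of this PR): loading the new dataset and checking the conventions documented above. The split sizes in the comments are the published SVHN counts, assumed here rather than taken from this diff.

import tensorlayer as tl

X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
print(X_train.shape)                 # (73257, 32, 32, 3) -- NHWC floats in [0, 1]
print(X_test.shape)                  # (26032, 32, 32, 3)
print(y_train.min(), y_train.max())  # 0 9 -- digit '0' is label 0, not 10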


def load_ptb_dataset(path='data'):
"""Load Penn TreeBank (PTB) dataset.

@@ -656,19 +757,19 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
url = 'http://press.liacs.nl/mirflickr/mirflickr25k/'

# download dataset
if folder_exists(path + "/mirflickr") is False:
if folder_exists(os.path.join(path, "mirflickr")) is False:
logging.info("[*] Flickr25k is nonexistent in {}".format(path))
maybe_download_and_extract(filename, path, url, extract=True)
del_file(path + '/' + filename)
del_file(os.path.join(path, filename))

# return images by the given tag.
# 1. image path list
folder_imgs = path + "/mirflickr"
folder_imgs = os.path.join(path, "mirflickr")
path_imgs = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
path_imgs.sort(key=natural_keys)

# 2. tag path list
folder_tags = path + "/mirflickr/meta/tags"
folder_tags = os.path.join(path, "mirflickr", "meta", "tags")
path_tags = load_file_list(path=folder_tags, regx='\\.txt', printable=False)
path_tags.sort(key=natural_keys)

@@ -679,7 +780,7 @@ def load_flickr25k_dataset(tag='sky', path="data", n_threads=50, printable=False
logging.info("[Flickr25k] reading images with tag: {}".format(tag))
images_list = []
for idx, _v in enumerate(path_tags):
tags = read_file(folder_tags + '/' + path_tags[idx]).split('\n')
tags = read_file(os.path.join(folder_tags, path_tags[idx])).split('\n')
# logging.info(idx+1, tags)
if tag is None or tag in tags:
images_list.append(path_imgs[idx])
@@ -722,6 +823,8 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
>>> images = tl.files.load_flickr1M_dataset(tag='zebra')

"""
import shutil
Member: we should almost always put imports at the top of the file:
https://stackoverflow.com/questions/1188640/good-or-bad-practice-in-python-import-in-the-middle-of-a-file
shutil is a standard Python module, so it is cheap to import.
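A minimal sketch of the suggested refactor, assuming the module header shown near the top of this diff:

import os
import shutil  # hoisted to module scope; standard library, cheap to import
import tarfile
import time
import zipfile

# load_flickr1M_dataset can then call shutil.move(...) directly,
# with no function-local import.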


path = os.path.join(path, 'flickr1M')
logging.info("[Flickr1M] using {}% of images = {}".format(size * 10, size * 100000))
images_zip = [
@@ -734,20 +837,21 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
for image_zip in images_zip[0:size]:
image_folder = image_zip.split(".")[0]
# logging.info(path+"/"+image_folder)
if folder_exists(path + "/" + image_folder) is False:
if folder_exists(os.path.join(path, image_folder)) is False:
# logging.info(image_zip)
logging.info("[Flickr1M] {} is missing in {}".format(image_folder, path))
maybe_download_and_extract(image_zip, path, url, extract=True)
del_file(path + '/' + image_zip)
os.system("mv {} {}".format(path + '/images', path + '/' + image_folder))
del_file(os.path.join(path, image_zip))
# os.system("mv {} {}".format(os.path.join(path, 'images'), os.path.join(path, image_folder)))
shutil.move(os.path.join(path, 'images'), os.path.join(path, image_folder))
else:
logging.info("[Flickr1M] {} exists in {}".format(image_folder, path))

# download tag
if folder_exists(path + "/tags") is False:
if folder_exists(os.path.join(path, "tags")) is False:
logging.info("[Flickr1M] tag files is nonexistent in {}".format(path))
maybe_download_and_extract(tag_zip, path, url, extract=True)
del_file(path + '/' + tag_zip)
del_file(os.path.join(path, tag_zip))
else:
logging.info("[Flickr1M] tags exists in {}".format(path))

@@ -761,17 +865,19 @@ def load_flickr1M_dataset(tag='sky', size=10, path="data", n_threads=50, printab
for folder in images_folder_list[0:size * 10]:
tmp = load_file_list(path=folder, regx='\\.jpg', printable=False)
tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.jpg
images_list.extend([folder + '/' + x for x in tmp])
images_list.extend([os.path.join(folder, x) for x in tmp])

# 2. tag path list
tag_list = []
tag_folder_list = load_folder_list(path + "/tags")
tag_folder_list.sort(key=lambda s: int(s.split('/')[-1])) # folder/images/ddd
tag_folder_list = load_folder_list(os.path.join(path, "tags"))

tag_folder_list.sort(key=lambda s: int(os.path.basename(s)))  # folder names are numbers, e.g. .../tags/ddd

for folder in tag_folder_list[0:size * 10]:
tmp = load_file_list(path=folder, regx='\\.txt', printable=False)
tmp.sort(key=lambda s: int(s.split('.')[-2])) # ddd.txt
tmp = [folder + '/' + s for s in tmp]
tmp = [os.path.join(folder, s) for s in tmp]
tag_list += tmp

# 3. select images
8 changes: 4 additions & 4 deletions tensorlayer/visualize.py
@@ -25,7 +25,7 @@
]


def read_image(image, path=''):
def read_image(image, path='_.png'):
"""Read one image.

Parameters
@@ -44,7 +44,7 @@ def read_image(image, path=''):
return scipy.misc.imread(os.path.join(path, image))


def read_images(img_list, path='', n_threads=10, printable=True):
def read_images(img_list, path='_.png', n_threads=10, printable=True):
"""Returns all images in list by given path and name of each image file.

Parameters
@@ -75,7 +75,7 @@ def read_images(img_list, path='', n_threads=10, printable=True):
return imgs


def save_image(image, image_path=''):
def save_image(image, image_path='_temp.png'):
"""Save a image.

Parameters
@@ -92,7 +92,7 @@ def save_image(image, image_path=''):
scipy.misc.imsave(image_path, image[:, :, 0])


def save_images(images, size, image_path=''):
def save_images(images, size, image_path='_temp.png'):
"""Save multiple images into one single image.

Parameters
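A hedged usage sketch of the new defaults (not part of the diff): when no path is given, the save helpers now fall back to '_temp.png' rather than an empty string.

import numpy as np
import tensorlayer as tl

img = np.random.rand(32, 32, 3)                  # dummy RGB image in [0, 1]
tl.vis.save_image(img)                           # writes to '_temp.png' by default
tl.vis.save_images(np.stack([img] * 4), [2, 2])  # 2x2 grid, also to '_temp.png'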