Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
157 lines (132 sloc) 5.09 KB
import logging
logging.basicConfig(format="[%(asctime)s] %(message)s", datefmt="%m-%d %H:%M:%S")
import os
import sys
import urllib
import pprint
import tarfile
import tensorflow as tf
import datetime
import numpy as np
import scipy.misc
import json
# Shorthand pretty-printer: pp(obj) pretty-prints obj via one shared
# PrettyPrinter instance.
_pretty_printer = pprint.PrettyPrinter()
pp = _pretty_printer.pprint
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class dotdict(dict):
    """A ``dict`` whose items are also reachable as attributes.

    ``d.key`` reads ``d['key']`` and yields ``None`` (instead of raising) when
    the key is absent. ``d.key = v`` stores ``d['key'] = v``. ``del d.key``
    removes the item and raises ``KeyError`` if it is missing.
    """

    def __getattr__(self, name):
        # Python only calls this when normal attribute lookup fails, so real
        # dict methods such as .keys() still take precedence over items.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
def mprint(matrix, pivot=0.5):
    """Print a 2-D grid as text, one line per row.

    A cell becomes ``'#'`` when its value is strictly greater than ``pivot``
    and a space otherwise.
    """
    for row in matrix:
        line = "".join(["#" if cell > pivot else " " for cell in row])
        print(line)
def show_all_variables(variables=None):
    """Print each variable's index, name, shape and element count, then the total.

    Args:
        variables: optional iterable of variables. Each must expose ``name``
            and ``get_shape()``, whose result supports ``as_list()``.
            Defaults to ``tf.trainable_variables()``, which was the only
            behaviour before.

    Returns:
        The total number of elements across all listed variables.
    """
    if variables is None:
        variables = tf.trainable_variables()
    total_count = 0
    for idx, var in enumerate(variables):
        shape = var.get_shape()
        # Element count is the product of the dimensions. np.prod([]) is 1.0
        # for a scalar, so coerce to int for clean printing and summing.
        count = int(np.prod(shape.as_list()))
        print("[%2d] %s %s = %s" % (idx, var.name, shape, count))
        total_count += count
    print("[Total] variable size: %s" % "{:,}".format(total_count))
    return total_count
def get_timestamp():
    """Return the current local time as a ``YYYY_MM_DD_HH_MM_SS`` string.

    The string is safe to use inside file and directory names.
    """
    now = datetime.datetime.now()
    return now.strftime('%Y_%m_%d_%H_%M_%S')
# binarize(images) is defined once in this module, just after occlude().
# A second, identical definition here was dead code shadowed at import time.
def save_images(images, height, width, n_row, n_col,
                cmin=0.0, cmax=1.0, directory="./", prefix="sample"):
    """Tile ``n_row * n_col`` images into one grid and save it as a JPEG.

    Image ``k``, in the order given, lands at grid row ``k // n_col`` and
    grid column ``k % n_col``.

    Args:
        images: array holding exactly ``n_row * n_col * height * width``
            values, e.g. ``(n_row * n_col, height, width)`` or the same with a
            trailing channel axis of size 1.
        height, width: size of one image in pixels.
        n_row, n_col: number of grid rows and columns.
        cmin, cmax: intensity range forwarded to ``scipy.misc.toimage``.
        directory: output directory. It is not created here.
        prefix: file-name prefix. The file is ``<prefix>_<i>.jpg`` with the
            smallest ``i`` that does not already exist.

    Returns:
        Path of the written file.
    """
    grid = images.reshape((n_row, n_col, height, width))
    # (row, col, y, x) -> (row, y, col, x): each band of `height` output rows
    # then holds one grid row, and every tile stays contiguous for any
    # n_row / n_col. Transposing to (col, y, row, x) instead only works when
    # n_row == n_col, and even then it transposes the grid.
    grid = grid.transpose(0, 2, 1, 3)
    grid = grid.reshape((height * n_row, width * n_col))

    filename_template = '%s_%s.jpg'
    i = 0
    while os.path.exists(os.path.join(directory, filename_template % (prefix, i))):
        i += 1
    path = os.path.join(directory, filename_template % (prefix, i))
    # NOTE: scipy.misc.toimage was removed in SciPy 1.2; this call needs an
    # older SciPy (with Pillow) to work.
    scipy.misc.toimage(grid, cmin=cmin, cmax=cmax).save(path)
    return path
def occlude(images, height, width):
    """Return a float32 copy of ``images`` with everything but the top half zeroed.

    ``images`` must have shape ``(N, height, width, 1)``; otherwise an
    ``AssertionError`` is raised. Rows ``[0, height // 2)`` are copied and all
    remaining rows are zero.
    """
    expected_shape = (len(images), height, width, 1)
    assert images.shape == expected_shape, "shape doesn't match expected shape"
    half = height // 2
    occluded = np.zeros(expected_shape, dtype='float32')
    occluded[:, :half] = images[:, :half]
    return occluded
def binarize(images):
    """Stochastically binarize ``images``, whose values are treated as probabilities.

    Each entry becomes 1.0 when a fresh uniform [0, 1) draw is below it and
    0.0 otherwise. The result is float32 and has the same shape as ``images``.
    """
    noise = np.random.uniform(low=0.0, high=1.0, size=images.shape)
    mask = np.less(noise, images)
    return mask.astype(np.float32)
def load_images(dataset_name, normalize=True):
    """Load a dataset together with its image geometry and batch samplers.

    Args:
        dataset_name: ``'mnist'`` is supported. ``'cifar'`` has no batch
            samplers yet and raises ``NotImplementedError``.
        normalize: currently unused; kept for interface compatibility.

    Returns:
        ``(dataset, image_height, image_width, num_channels,
        next_train_batch, next_test_batch)``. Each sampler takes a batch size
        and returns element ``[0]`` of ``next_batch`` (presumably the images;
        confirm against the mnist loader).

    Raises:
        NotImplementedError: for ``'cifar'``.
        ValueError: for any other unrecognised ``dataset_name``.
    """
    if dataset_name == 'mnist':
        dataset = mnist.load_mnist()
        image_height, image_width, num_channels = MNIST_PARAMS

        def next_train_batch(batch_size):
            return dataset.train.next_batch(batch_size)[0]

        def next_test_batch(batch_size):
            return dataset.test.next_batch(batch_size)[0]

        return (dataset, image_height, image_width, num_channels,
                next_train_batch, next_test_batch)
    if dataset_name == 'cifar':
        # No CIFAR batch samplers exist, so this branch cannot produce a
        # usable result. Fail before loading the data rather than afterwards.
        raise NotImplementedError(
            'cifar batch samplers (next_train_batch / next_test_batch) '
            'are not implemented')
    # A bare string cannot be raised (that is a TypeError); use a real
    # exception type.
    raise ValueError('{0} is not a supported dataset'.format(dataset_name))
def get_shape(tensor):
    """Return ``tensor``'s static shape as a plain Python list.

    Equivalent to ``tensor.get_shape().as_list()``.
    """
    static_shape = tensor.get_shape()
    return static_shape.as_list()
def get_model_dir(config, exceptions=None):
    """Build a checkpoint directory path from the configuration flags.

    Each flag becomes one path component ``name=value``. List values are
    joined with commas. ``data`` comes first, when present; the other flags
    follow in the flag mapping's iteration order.

    Args:
        config: object whose ``__dict__['__flags']`` maps flag names to
            values (presumably an old-style ``tf.app.flags.FLAGS``; confirm).
        exceptions: optional collection of flag names to leave out.
            ``None`` means no flag is excluded.

    Returns:
        ``'checkpoints/<name=value>/.../'``, always with a trailing separator.
    """
    attrs = config.__dict__['__flags']
    # `None` used to crash on `key not in exceptions`.
    exceptions = exceptions or ()
    # Put 'data' first without listing it twice. A list is built explicitly
    # because a list cannot be concatenated with dict.keys() on Python 3.
    keys = [key for key in attrs if key != 'data']
    if 'data' in attrs:
        keys.insert(0, 'data')

    names = []
    for key in keys:
        # Only use useful flags
        if key in exceptions:
            continue
        value = attrs[key]
        if isinstance(value, list):
            value = ",".join(str(item) for item in value)
        names.append("%s=%s" % (key, value))
    return os.path.join('checkpoints', *names) + '/'
def setup_model_saving(model_name, data, hyperparams=None, root_dir='run/'):
    """Create a fresh run directory ``<root_dir>/<data>/<model_name><i>``.

    ``i`` is the smallest non-negative integer whose directory does not exist
    yet. The directory is created, along with any missing parents. If
    ``hyperparams`` is given, it is written to ``params.json`` inside it.

    Args:
        model_name: prefix of the run directory name.
        data: dataset name, used as an intermediate directory level.
        hyperparams: optional JSON-serialisable object to store.
        root_dir: base directory for all runs.

    Returns:
        Path of the newly created run directory.
    """
    def candidate(index):
        # Interpolate only the index. A '%s' template built from model_name,
        # data and root_dir would break on any '%' they contain.
        return os.path.join(root_dir, data, '%s%d' % (model_name, index))

    # iterate until we find an index that hasn't been taken yet.
    i = 0
    while os.path.exists(candidate(i)):
        i += 1
    name = candidate(i)
    # Create the folder so params.json (and later outputs) can be written.
    os.makedirs(name)
    if hyperparams is not None:
        with open(os.path.join(name, 'params.json'), 'w') as jsonfile:
            json.dump(hyperparams, jsonfile)
    return name
def preprocess_conf(conf):
    """Walk the flag mapping stored on ``conf.__flags``.

    As written, this has no lasting effect. It reads ``conf.__flags``
    (``AttributeError`` if absent), iterates its items and lower-cases each
    flag name into a local that is then discarded. It returns ``None``.
    NOTE(review): looks unfinished; confirm what preprocessing was intended.
    """
    flags = getattr(conf, '__flags')
    for name, _value in flags.items():
        _lowered = name.lower()
def check_and_create_dir(directory):
    """Create ``directory``, including missing parents, unless it already exists.

    Logs at INFO level whether the directory was created or skipped.
    """
    if os.path.exists(directory):
        logger.info('Skip creating directory: %s', directory)
        return
    logger.info('Creating directory: %s', directory)
    os.makedirs(directory)
def maybe_download_and_extract(dest_directory, data_url=None):
    """Download and extract the tarball from Alex's website.

    The ``.tar.gz`` archive is saved into ``dest_directory``, which is
    created if missing. The download is skipped when the archive file is
    already present. The archive is extracted into ``dest_directory`` on
    every call.

    Args:
        dest_directory: directory that receives the archive and its contents.
        data_url: URL of the archive. Defaults to the module-level
            ``DATA_URL``.
    """
    # ``import urllib`` alone does not expose urlretrieve on Python 3.
    try:
        from urllib.request import urlretrieve
    except ImportError:  # Python 2
        from urllib import urlretrieve

    url = DATA_URL if data_url is None else data_url
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # urlretrieve reports total_size == -1 when the size is unknown;
            # skip the percentage instead of printing nonsense.
            if total_size > 0:
                percent = min(float(count * block_size) / float(total_size) * 100.0, 100.0)
                sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
                sys.stdout.flush()
        filepath, _ = urlretrieve(url, filepath, _progress)
        print()  # end the carriage-return progress line
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    with tarfile.open(filepath, 'r:gz') as tar:
        # The archive comes from the network. Where this Python supports it,
        # use the 'data' filter, which rejects absolute paths, '..' members
        # and special files.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dest_directory, filter='data')
        else:
            tar.extractall(dest_directory)