In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from get_data import get_memristor_data
from main import MemristorAutoEncoder
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'gray'
%matplotlib inline

In [2]:
# Pickled memristor characterization data; parsing is done entirely by
# get_data.get_memristor_data (format not visible here — TODO confirm).
path = 'Data/Partial_Reset_PCM.pkl'

# Passed as the second argument of get_memristor_data (presumably the number
# of memristor devices/samples — TODO confirm in get_data.py).
n_mem = 400
# Normalization bounds forwarded to get_memristor_data; the same values are
# passed to the model below as 'vmin' / 'vmax'.
norm_min, norm_max = -0.9, 0.9

# NOTE(review): the meaning of mus_data / sigs_data (presumably per-voltage
# mean / std of the device response) and of the orig_* bounds (presumably the
# un-normalized V and R ranges) is inferred from names only — confirm.
(vs_data, mus_data, sigs_data, 
 orig_VMIN, orig_VMAX, orig_RMIN, orig_RMAX) = get_memristor_data(path, n_mem, norm_min=norm_min, norm_max=norm_max)

# n_m (second axis of vs_data) becomes the autoencoder's code width: the last
# encoder layer size and the first decoder layer size below.
n_samp, n_m = vs_data.shape

In [3]:
dataset = ['mnist', 'imagenet'][1]

In [None]:
if dataset == 'mnist':
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    from utils import DataAndNoiseGenerator
    train_data_obj = DataAndNoiseGenerator(mnist.train.images, n_m)
    summary_data_obj = DataAndNoiseGenerator(mnist.test.images, n_m)
    _, data_dim = mnist.train.images.shape
    image_shape = (28, 28)
    data_obj = DataAndNoiseGenerator(mnist.test.images, n_m)
    
elif dataset == 'imagenet':
    import os
    from utils import FileAndNoiseGenerator
    
    data_partition = ['train', 'valid'][0]
    size = [32, 64][0]
    file_directory = '/home/aga/imagenet_data/{}_{}x{}'.format(data_partition, size, size)
    file_list = os.listdir(file_directory)
    n_train = 60000
    n_test  = 1000
    train_data_obj = FileAndNoiseGenerator(
        file_list[0:n_train], file_directory, n_m)
    summary_data_obj = FileAndNoiseGenerator(
        file_list[n_train:n_train + n_test], file_directory, n_m, max_batches=1)
    data_dim = size ** 2 * 3
    data_obj = summary_data_obj
    image_shape = (32, 32, 3)
    
    
else: 
    raise ValueError('Invalid Dataset {}'.format(dataset))

In [None]:
# Build the memristor autoencoder (class from main.py). The encoder narrows a
# data_dim-wide input down to n_m values; the decoder mirrors those widths back.
mae = MemristorAutoEncoder(
    # NOTE(review): gamma's meaning is defined in main.py, not visible here.
    gamma=10,
    data_dim=data_dim,
    # Memristor data loaded by get_memristor_data; vmin/vmax are the same
    # normalization bounds passed to it, the *_range tuples the orig_* bounds
    # it returned (presumably un-normalized V and R ranges — TODO confirm).
    memristor_data={
        'vs_data': vs_data,
        'mus_data': mus_data,
        'sigs_data': sigs_data,
        'vmin': norm_min,
        'vmax': norm_max,
        'orig_v_range': (orig_VMIN, orig_VMAX), 
        'orig_r_range': (orig_RMIN, orig_RMAX)
    },
    encoder_params={
        'layer_sizes': [data_dim, 1000, 400, n_m],
        'non_linearity': 'tanh'
    },
    decoder_params={
        'layer_sizes': [n_m, 400, 1000, data_dim],
        'non_linearity': 'tanh'
    },
    optimizer_params={
        'batch_size': 50,
        'num_epochs': 50,
        'method': 'adam',
        'learning_rate': 0.001
    },
    output_dir='output',
    # presumably None = fresh weights; swap in the commented-out line below to
    # pass a checkpoint path instead (TODO confirm against main.py).
    param_file=None)
#     param_file='tmp/model.ckpt')

In [None]:
# Train the autoencoder; summary_data_obj is the second argument (presumably
# used for periodic evaluation/summaries — TODO confirm in main.py).
mae.fit(train_data_obj, summary_data_obj)

In [None]:
# Run the trained network on data_obj. Later cells read the 'x', 'xh', 'v'
# and 'r' entries of the returned dict (presumably input, reconstruction,
# voltages and log-resistances, judging by names and plot labels — confirm).
eval_vals = mae.inspect_network(data_obj)

In [None]:
def RV_density_plot(Vs, Rs, cmap, bins=300, figsize=(5, 10)):
    """Draw a 2-D histogram (density heat map) of paired V/R values.

    Parameters
    ----------
    Vs, Rs : array_like
        Paired values. Both are flattened, so they must contain the same
        number of elements. Vs goes on the x axis, Rs on the y axis.
    cmap : matplotlib colormap
        Colormap forwarded to ``plt.imshow``.
    bins : int or sequence, optional
        Forwarded to ``np.histogram2d``.
    figsize : tuple, optional
        Size of the new figure (default keeps the previous (5, 10)).

    Pairs where either value is NaN or +/-inf are dropped first, because
    ``np.histogram2d`` cannot auto-detect a bin range over non-finite data.
    """
    v_flat = np.ravel(Vs)
    r_flat = np.ravel(Rs)

    # Mask on both coordinates: filtering only Rs still let a NaN/inf in Vs
    # reach histogram2d and raise.
    finite = np.isfinite(v_flat) & np.isfinite(r_flat)
    v_flat = v_flat[finite]
    r_flat = r_flat[finite]

    heatmap, xedges, yedges = np.histogram2d(v_flat, r_flat, bins=bins)
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

    plt.figure(figsize=figsize)
    # histogram2d indexes [x_bin, y_bin]; transpose so Vs runs along x, and
    # origin='lower' puts the smallest Rs at the bottom.
    plt.imshow(heatmap.T, extent=extent, origin='lower', interpolation='nearest',
               cmap=cmap)

In [None]:
# Density of all (v, r) pairs returned by inspect_network.
RV_density_plot(eval_vals['v'], eval_vals['r'], cmap=plt.cm.jet)

In [None]:
# The same pairs as a raw scatter plot.
plt.scatter(eval_vals['v'], eval_vals['r'])

In [None]:
# Distribution of all v values, 50 bins.
plt.hist(eval_vals['v'].ravel(), bins=50)

In [None]:
# NOTE(review): the -3 shift applied to eval_vals['r'] is unexplained — TODO document.
RV_density_plot(eval_vals['v'], eval_vals['r']-3, cmap=plt.cm.jet)

# Overlay a normalized histogram of v on the current axes. `density=True`
# replaces `normed=True`, which matplotlib deprecated in 2.1 and removed in 3.1.
plt.hist(eval_vals['v'].ravel(), bins=50, density=True);

In [None]:
# Per-sample ratio of signal energy sum(x**2) to reconstruction-error energy
# sum((x - xh)**2), summed over the last axis (one value per row for 2-D x,
# not averaged; snr() further down averages a mean-based ratio instead).
(eval_vals['x'] ** 2).sum(axis=-1) / ((eval_vals['x'] - eval_vals['xh']) ** 2).sum(axis=-1)

In [None]:
# For a few sample indices, show input, reconstruction, and that sample's
# (v, r) pairs side by side — one row of three panels per sample.
q_ = [0, 10, 12]
n_imgs = len(q_)
plt.figure(figsize=(15, 4 * n_imgs))
for i, q in enumerate(q_):
    # Input and reconstruction share one display range so they can be compared
    # by eye (previously only the reconstruction was pinned to [0, 1]).
    # NOTE(review): assumes inputs lie in [0, 1] — confirm for the imagenet generator.
    plt.subplot(n_imgs, 3, 1 + 3 * i)
    plt.imshow(eval_vals['x'][q].reshape(image_shape), vmin=0, vmax=1)
    plt.title('input #{}'.format(q))
    plt.colorbar()

    plt.subplot(n_imgs, 3, 2 + 3 * i)
    plt.imshow(eval_vals['xh'][q].reshape(image_shape), vmin=0, vmax=1)
    plt.title('reconstruction #{}'.format(q))
    plt.colorbar()

    plt.subplot(n_imgs, 3, 3 + 3 * i)
    plt.scatter(eval_vals['v'][q], eval_vals['r'][q])
    plt.xlabel('V')
    plt.ylabel('log(R)')

In [None]:
# Short aliases for the evaluated inputs and their reconstructions.
x, xh = eval_vals['x'], eval_vals['xh']

In [None]:
def snr(u, v):
    """Average signal-to-noise power ratio of a reconstruction.

    For every row, the mean of ``u**2`` (signal power) is divided by the mean
    of ``(u - v)**2`` (error power); the per-row ratios are then averaged.
    The result is a linear ratio, not decibels. A row reconstructed exactly
    yields an infinite ratio (numpy division by zero).

    Parameters
    ----------
    u : numpy.ndarray
        Reference signal; statistics are taken along axis 1.
    v : numpy.ndarray
        Reconstruction of ``u``, broadcast-compatible with it.
    """
    signal_power = np.mean(u ** 2, axis=1)
    error_power = np.mean((u - v) ** 2, axis=1)
    return np.mean(signal_power / error_power)

In [None]:
# Average per-sample power SNR (linear ratio) of the raw reconstruction.
snr(x, xh)

In [None]:
# Same metric after clipping the reconstruction into [0, 1].
snr(x, np.clip(xh, 0, 1))

In [None]:
# NOTE(review): this cell redefines RV_density_plot (an earlier cell defines a
# function with the same name); keep the two identical or delete one.
def RV_density_plot(Vs, Rs, cmap, bins=300, figsize=(5, 10)):
    """Draw a 2-D histogram (density heat map) of paired V/R values.

    Parameters
    ----------
    Vs, Rs : array_like
        Paired values. Both are flattened, so they must contain the same
        number of elements. Vs goes on the x axis, Rs on the y axis.
    cmap : matplotlib colormap
        Colormap forwarded to ``plt.imshow``.
    bins : int or sequence, optional
        Forwarded to ``np.histogram2d``.
    figsize : tuple, optional
        Size of the new figure (default keeps the previous (5, 10)).

    Pairs where either value is NaN or +/-inf are dropped first, because
    ``np.histogram2d`` cannot auto-detect a bin range over non-finite data.
    """
    v_flat = np.ravel(Vs)
    r_flat = np.ravel(Rs)

    # Without this mask any NaN/inf in either input makes histogram2d raise.
    finite = np.isfinite(v_flat) & np.isfinite(r_flat)
    v_flat = v_flat[finite]
    r_flat = r_flat[finite]

    heatmap, xedges, yedges = np.histogram2d(v_flat, r_flat, bins=bins)
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

    plt.figure(figsize=figsize)
    # histogram2d indexes [x_bin, y_bin]; transpose so Vs runs along x, and
    # origin='lower' puts the smallest Rs at the bottom.
    plt.imshow(heatmap.T, extent=extent, origin='lower', interpolation='nearest',
               cmap=cmap)

In [None]:
# Joint density of two v columns: eval_vals['v'][:, 0] (x) vs eval_vals['v'][:, 3] (y).
RV_density_plot(eval_vals['v'][:, 0], eval_vals['v'][:, 3], plt.cm.jet, bins=100)

In [None]:
# `plt.scatter()` with no arguments raises TypeError (x and y are required).
# NOTE(review): scattering the same two v columns as the density plot above is
# a guess at the intended data — TODO confirm.
plt.scatter(eval_vals['v'][:, 0], eval_vals['v'][:, 3])
plt.axis('equal')

In [None]:
# Scratch arithmetic: evaluates to 6272 (presumably bits in a 28x28 image at
# 8 bits per pixel — TODO confirm or delete).
28 * 28 * 8

In [None]:
# Interactive help lookup (IPython `?` syntax); remove before sharing.
plt.imsave?

In [None]:
# Save the first input sample to disk. Reshape with image_shape (set by the
# dataset cell) rather than a hard-coded (28, 28), which only fits MNIST.
plt.imsave('img.jpg', eval_vals['x'][0].reshape(image_shape))

In [None]:
# Show the current working directory (IPython automagic for %pwd).
pwd

In [None]:
# Absolute, machine-specific directory (64x64 training images, per its name)
# explored by the cells below.
image_dir = '/home/aga/imagenet_data/train_64x64'

In [None]:
import os

In [None]:
# File names in image_dir; os.listdir returns them in arbitrary order.
files = os.listdir(image_dir)

In [None]:
# Sort in place so indexing is reproducible (re-running the previous cell
# undoes this).
files.sort()

In [None]:
# First file name after sorting.
files[0]

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load one image by a hard-coded file name to inspect its array format.
img = plt.imread(os.path.join(image_dir, '0000010.png'))

In [None]:
files[-1]

In [None]:
img.shape

In [None]:
plt.imshow(img)

In [None]:
# Number of array elements (not bytes).
img.size

In [None]:
import numpy as np

In [None]:
# NOTE(review): hard-coded file count; the pattern_batch_generator call further
# down passes n_files=1281150 instead — TODO reconcile (e.g. use len(files)).
n_files = 1281149

In [None]:
image_dir

In [None]:
# Zero-padded 7-digit index; yields '/home/aga/imagenet_data/train_64x64/0000010.png'
# (presumably the on-disk naming, cf. '0000010.png' loaded above).
'/home/aga/imagenet_data/train_64x64/{:07d}.png'.format(10)

In [None]:
# NOTE: the later pattern_batch_generator call passes its own 32x32 pattern
# string, not this variable.
pattern = '/home/aga/imagenet_data/train_64x64/{:07d}.png'

In [None]:
def pattern_batch_generator(pattern, n_files, batch_size=1000):
    """Yield batches of flattened images whose paths come from a format pattern.

    Files are numbered from 1: batch ``i`` reads ``pattern.format(k)`` for
    ``k`` in ``batch_size * i + 1 .. batch_size * (i + 1)``. Only full batches
    are produced, so up to ``batch_size - 1`` trailing files are skipped.

    Parameters
    ----------
    pattern : str
        ``str.format`` pattern taking one integer, e.g. ``'.../{:07d}.png'``.
    n_files : int
        Number of files available, numbered ``1 .. n_files``.
    batch_size : int, optional
        Images per batch.

    Yields
    ------
    numpy.ndarray
        Float array of shape ``(batch_size, n_features)`` with one flattened
        image per row; ``n_features`` comes from the first image of the batch.
    """
    # Floor division: `/` gives a float on Python 3 and range() rejects it.
    num_batches = n_files // batch_size

    for i in range(num_batches):
        batch = None
        for row in range(batch_size):
            img = plt.imread(pattern.format(batch_size * i + row + 1))
            if batch is None:
                batch = np.zeros((batch_size, img.size))
            batch[row] = img.ravel()
        yield batch

In [None]:
# Listing of the 32x32 training images. NOTE(review): this rebinds `files`,
# which the file_batch_generator cell further down pairs with image_dir (the
# 64x64 directory) — TODO confirm that mix is intended.
files = os.listdir('/home/aga/imagenet_data/train_32x32/')

In [None]:
files.sort()

In [None]:
files[-1]

In [None]:
# Generator over the numbered 32x32 images. NOTE(review): n_files=1281150 here
# vs n_files = 1281149 set earlier — TODO reconcile.
gen = pattern_batch_generator(
    pattern='/home/aga/imagenet_data/train_32x32/{:07d}.png',
    n_files=1281150)

In [None]:
%%time
for batch in gen:
    print batch.mean()
    break

In [None]:
def file_batch_generator(files, batch_size, directory, max_batches=100):
    """Yield batches of flattened images read from a list of file names.

    Files are consumed in list order, ``batch_size`` at a time; only full
    batches are produced, so up to ``batch_size - 1`` trailing files are
    skipped.

    Parameters
    ----------
    files : sequence of str
        File names, joined onto ``directory`` with ``os.path.join``.
    batch_size : int
        Images per batch.
    directory : str
        Directory containing the files.
    max_batches : int or None, optional
        Stop after this many batches; ``None`` means no limit.

    Yields
    ------
    numpy.ndarray
        Float array of shape ``(batch_size, n_features)`` with one flattened
        image per row; ``n_features`` comes from the first image of the batch.
    """
    # Floor division: `/` gives a float on Python 3 and range() rejects it.
    num_batches = len(files) // batch_size
    if max_batches is not None:
        num_batches = min(num_batches, max_batches)

    for i in range(num_batches):
        file_batch = files[i * batch_size:(i + 1) * batch_size]
        batch = None
        for j, fn in enumerate(file_batch):
            img = plt.imread(os.path.join(directory, fn))
            if batch is None:
                batch = np.zeros((batch_size, img.size))
            batch[j] = img.ravel()
        yield batch

In [None]:
# NOTE(review): `files` was last bound to the train_32x32 listing, but
# image_dir is the train_64x64 directory, so this reads 64x64 images under the
# 32x32 file names (works only if both folders share names) — TODO confirm.
gen = file_batch_generator(files, 10, image_dir)

In [None]:
files[1]

In [None]:
%%time
for datum in gen:
    print datum.mean()
    break

In [None]:
datum.mean