In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Choose GPU (this may not be needed on your computer)

In [2]:
# Enumerate GPUs in PCI bus order, then expose only device index 1 to this
# process.
# NOTE(review): the index is machine-specific; on a host with a single GPU
# index 1 presumably does not exist and no GPU would be visible -- adjust or
# skip this cell as needed.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [3]:
import tensorflow as tf

# Ask TensorFlow to grow GPU memory on demand instead of reserving it all
# up front. Guard against machines where no GPU is visible (indexing an
# empty device list would raise IndexError), and apply the setting to every
# visible GPU so the configuration is consistent across devices.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if gpu_devices:
    for gpu_device in gpu_devices:
        tf.config.experimental.set_memory_growth(gpu_device, True)
else:
    print("No GPU visible to TensorFlow; running on CPU.")

### load packages

In [4]:
from tfumap.umap import tfUMAP



In [5]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import umap
import pandas as pd

### Load dataset

In [6]:
from tensorflow.keras.datasets import cifar10

In [7]:
# Load CIFAR-10, rescale pixel values by 1/255 to float32, and flatten each
# image into a single feature vector.
(train_images, Y_train), (test_images, Y_test) = cifar10.load_data()
X_train = (train_images / 255.).astype('float32')
X_test = (test_images / 255.).astype('float32')
# `-1` lets NumPy infer the flattened feature length; this replaces
# `np.product`, which is deprecated and removed in NumPy 2.0.
X_train = X_train.reshape((len(X_train), -1))
X_test = X_test.reshape((len(X_test), -1))

# Hold out the last `n_valid` training examples as a validation set.
n_valid = 10000
X_valid = X_train[-n_valid:]
Y_valid = Y_train[-n_valid:]
X_train = X_train[:-n_valid]
Y_train = Y_train[:-n_valid]

# The arrays above are already flat, so these reshapes leave the shape
# unchanged; the `*_flat` names are kept because later cells refer to them.
X_train_flat = X_train.reshape((len(X_train), -1))
X_test_flat = X_test.reshape((len(X_test), -1))
X_valid_flat = X_valid.reshape((len(X_valid), -1))
print(len(X_train), len(X_valid), len(X_test))

40000 10000 10000


### define networks

In [8]:
# Dimensionality of the embedding produced by the encoder.
n_components = 64
# Shape of a single (unflattened) input image fed to the encoder.
dims = (32, 32, 3)

In [9]:
# Convolutional encoder: three stride-2 ReLU convolutions with a growing
# number of filters, followed by a dense head that projects to
# `n_components` outputs (no activation on the final layer).
encoder = tf.keras.Sequential()
encoder.add(tf.keras.layers.InputLayer(input_shape=dims))
for n_filters in (32, 64, 128):
    encoder.add(
        tf.keras.layers.Conv2D(
            filters=n_filters, kernel_size=3, strides=(2, 2), activation="relu"
        )
    )
encoder.add(tf.keras.layers.Flatten())
encoder.add(tf.keras.layers.Dense(units=512, activation="relu"))
encoder.add(tf.keras.layers.Dense(units=n_components))

In [10]:
# Decoder network: maps an `n_components`-dimensional embedding back to an
# image. Two dense layers expand the embedding to a 4x4x128 feature map,
# three stride-2 transposed convolutions upsample it (4 -> 8 -> 16 -> 32),
# and a final stride-1 transposed convolution with a sigmoid produces the
# 3-channel output in [0, 1].
decoder = tf.keras.Sequential([
    # `(n_components,)` is a one-element tuple; the previous `(n_components)`
    # was just a parenthesised int, not a shape tuple.
    tf.keras.layers.InputLayer(input_shape=(n_components,)),
    tf.keras.layers.Dense(units=512, activation="relu"),
    tf.keras.layers.Dense(units=4 * 4 * 128, activation="relu"),
    tf.keras.layers.Reshape(target_shape=(4, 4, 128)),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=3, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"
    )
])

### Create model and train

In [11]:
# Collect the tfUMAP settings in one place, then build the model from them.
# The encoder/decoder networks defined earlier are passed in, together with
# the held-out validation split and the image shape.
umap_params = {
    "direct_embedding": False,
    "verbose": True,
    "negative_sample_rate": 5,
    "training_epochs": 5,
    "encoder": encoder,
    "decoding_method": "network",
    "decoder": decoder,
    "valid_X": X_valid,
    "valid_Y": Y_valid,
    "dims": dims,
}
embedder = tfUMAP(**umap_params)

In [None]:
# Train the model on the flattened training images and keep the result in
# `z` (indexed as a 2-D array by the plotting cell below).
# NOTE(review): presumably `z` has one row per training image and
# `n_components` columns -- confirm against tfUMAP.fit_transform.
z = embedder.fit_transform(X_train_flat)

tfUMAP(decoder=<tensorflow.python.keras.engine.sequential.Sequential object at 0x7f6d71652f98>,
       decoding_method='network', dims=(32, 32, 3),
       encoder=<tensorflow.python.keras.engine.sequential.Sequential object at 0x7f6d71497400>,
       negative_sample_rate=5,
       optimizer=<tensorflow.python.keras.optimizer_v2.adam.Adam object at 0x7f6d70f3d6d8>,
       tensorboard_logdir='/tmp/ten...
       [0.45490196, 0.2784314 , 0.10196079, ..., 0.14901961, 0.07450981,
        0.01960784],
       ...,
       [0.13725491, 0.69803923, 0.92156863, ..., 0.04705882, 0.12156863,
        0.19607843],
       [0.7411765 , 0.827451  , 0.9411765 , ..., 0.7647059 , 0.74509805,
        0.67058825],
       [0.8980392 , 0.8980392 , 0.9372549 , ..., 0.6392157 , 0.6392157 ,
        0.6313726 ]], dtype=float32),
       valid_Y=array([[1],
       [8],
       [5],
       ...,
       [9],
       [1],
       [1]], dtype=uint8))
Construct fuzzy simplicial set
Thu Jul 16 12:46:43 2020 Finding Nearest Nei

HBox(children=(IntProgress(value=0, description='epoch', max=5, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='batch', max=7121, style=ProgressStyle(description_width='init…

### Plot model output

In [None]:
# Scatter the first two columns of the embedding, coloured by class label.
# Only columns 0 and 1 of `z` are drawn, so the axes are labelled
# explicitly to make clear which dimensions the figure shows.
fig, ax = plt.subplots(figsize=(8, 8))
sc = ax.scatter(
    z[:, 0],
    z[:, 1],
    c=Y_train.flatten(),
    cmap="tab10",
    s=0.1,
    alpha=0.5,
    rasterized=True,
)
ax.axis('equal')
ax.set_xlabel("embedding dimension 0")
ax.set_ylabel("embedding dimension 1")
ax.set_title("UMAP in Tensorflow embedding", fontsize=20)
plt.colorbar(sc, ax=ax, label="class label");

### View loss

In [None]:
from tfumap.umap import retrieve_tensors
import seaborn as sns

In [None]:
# Read the values logged under the embedder's tensorboard log directory
# into a dataframe and preview the first three rows.
loss_df = retrieve_tensors(embedder.tensorboard_logdir)
loss_df.head(3)

In [None]:
# List the distinct values of the `group` column (used as the hue in the
# loss plots).
loss_df["group"].unique()

In [None]:
# Plot each logged loss on its own panel with a log-scaled step axis.
# A single loop replaces the two copy-pasted plotting stanzas, and ending
# the cell on the loop avoids echoing a stray `Text(...)` repr as output.
fig, axs = plt.subplots(ncols=2, figsize=(20, 5))
loss_panels = (
    ('umap_loss', 'UMAP loss'),
    ('recon_loss', 'Reconstruction loss'),
)
for ax, (loss_name, panel_title) in zip(axs, loss_panels):
    sns.lineplot(
        x="step",
        y="val",
        hue="group",
        data=loss_df[loss_df.variable == loss_name],
        ax=ax,
    )
    ax.set_xscale('log')
    ax.set_title(panel_title)

### Save output

In [None]:
from tfumap.paths import ensure_dir, MODEL_DIR

In [None]:
# Directory for this run's artefacts. The embedding dimensionality segment
# is derived from `n_components` (currently 64, so the path is unchanged)
# rather than hard-coded, so it cannot drift from the model configuration.
output_dir = MODEL_DIR / 'projections' / 'cifar10' / str(n_components) / 'recon-network'
ensure_dir(output_dir)

In [None]:
# Persist the trained embedder under `output_dir`.
embedder.save(output_dir)

In [None]:
# Store the loss dataframe next to the model as a pickle file.
loss_df.to_pickle(output_dir / 'loss_df.pickle')

In [None]:
# Store the training-set embedding `z` as a NumPy .npy file.
np.save(output_dir / 'z.npy', z)

### View reconstructions on test data

In [None]:
# Embed the held-out test images (X_test was flattened when the dataset
# was loaded).
z_test = embedder.transform(X_test)

In [None]:
# Map the test embeddings back to input space; the result is reshaped to
# images for display in the next cell.
# NOTE(review): presumably this runs the decoder network -- confirm against
# tfUMAP.inverse_transform.
X_test_recon = embedder.inverse_transform(z_test)

In [None]:
# Show the first `nex` test images (top row) above their reconstructions
# (bottom row).
nex = 10
# squeeze=False keeps `axs` two-dimensional even when nex == 1, so the
# [row, column] indexing below always works.
fig, axs = plt.subplots(ncols=nex, nrows=2, figsize=(3 * nex, 3 * 2), squeeze=False)
for i in range(nex):
    # Reshape with the shared `dims` setting instead of repeating 32, 32, 3.
    axs[0, i].imshow(X_test[i].reshape(dims))
    axs[1, i].imshow(X_test_recon[i].reshape(dims))
axs[0, 0].set_title("original")
axs[1, 0].set_title("reconstruction")
for ax in axs.flat:
    ax.axis('off')