In [None]:
import keras
from keras.layers import *
#from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.framework.ops import disable_eager_execution
# Switch TensorFlow to graph (non-eager) mode before any layer or model is created.
# NOTE(review): presumably required because the VAE loss layer further down reads the
# encoder's symbolic tensors from module scope — confirm before removing.
disable_eager_execution()

In [None]:
# Load the image stack that is split into training and test sets below.
# NOTE(review): the array layout is not visible here; later cells index it as
# (sample, axis 1, axis 2, channel) — confirm against the file.
cellData = np.load("/content/drive/MyDrive/cellData.npy")

In [None]:
from sklearn.model_selection import train_test_split
# Hold out 25% of the samples as a test set; the fixed random_state keeps the split repeatable.
trainData, testData = train_test_split(cellData, test_size = 0.25, random_state=1)

In [None]:
# Image geometry is read from the data itself, assuming a channels-last layout
# (sample, height, width, channel). Axis 1 is the row count (height) and axis 2 the
# column count (width); taking them the other way round only works for square images,
# because the reshape below would otherwise scramble pixels.
img_height = trainData.shape[1]
img_width = trainData.shape[2]
num_channels = trainData.shape[3]
input_shape = (img_height, img_width, num_channels)

# With the layout above these reshapes leave the pixel order untouched; they mainly
# guarantee that both splits have exactly the shape the model is built for.
x_train = trainData.reshape((trainData.shape[0],) + input_shape)
x_test = testData.reshape((testData.shape[0],) + input_shape)

In [None]:
# Preview the first four channels of training sample 42 in a 2x2 grid of panels.
plt.figure(1)
for channel in range(4):
    plt.subplot(2, 2, channel + 1)
    plt.imshow(x_train[42][:, :, channel], cmap='gray')
plt.show()

In [None]:
# Dimensionality of the latent space: the encoder heads each emit this many values.
latent_dim = 2

# (filters, strides) for each 3x3, 'same'-padded, ReLU convolution of the encoder.
# Only the second convolution downsamples (stride 2).
encoder_conv_specs = [(32, (1, 1)), (64, (2, 2)), (128, (1, 1)), (512, (1, 1))]

input_img = Input(shape=input_shape, name='encoder_input')
x = input_img
for n_filters, conv_strides in encoder_conv_specs:
    x = Conv2D(n_filters, 3, strides=conv_strides, padding='same', activation='relu')(x)

# Keep the feature-map shape so the decoder can rebuild it from a flat vector.
conv_shape = K.int_shape(x)

x = Flatten()(x)
x = Dense(32, activation='relu')(x)

In [None]:
# Two parallel linear heads on the shared encoder features, one value per latent dimension.
# Despite its name, `z_sigma` is consumed as a log-variance: the sampling function
# exponentiates z_sigma / 2 and the KL term uses exp(z_sigma).
z_mu = Dense(latent_dim, name='latent_mu')(x)
z_sigma = Dense(latent_dim, name='latent_sigma')(x)

In [None]:
def sample_z(args):
  """Draw a latent sample using the reparameterisation trick.

  args: a pair ``(z_mu, z_sigma)`` of backend tensors. ``z_sigma`` is halved and
  exponentiated, i.e. it is treated as a log-variance.

  Returns ``z_mu + exp(z_sigma / 2) * eps`` where ``eps`` is standard-normal noise
  with one row per batch element and one column per latent dimension.
  """
  mean, log_var = args
  batch_size = K.shape(mean)[0]       # dynamic: only known when the graph runs
  latent_size = K.int_shape(mean)[1]  # static: fixed when the model is built
  noise = K.random_normal(shape=(batch_size, latent_size))
  std = K.exp(log_var / 2)
  return mean + std * noise

In [None]:
# Sampling layer: wraps `sample_z` in a Lambda so the random draw is part of the model;
# the declared output has `latent_dim` values per sample.
z = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([z_mu, z_sigma])

In [None]:
# Encoder model exposing the latent mean head, the log-variance head and a sampled z.
encoder = Model(input_img, [z_mu, z_sigma, z], name='encoder')
# summary() prints the table itself and returns None, so wrapping it in print()
# only adds a stray "None" line to the output.
encoder.summary()

In [None]:
# Decoder: maps a latent vector back to an image with `num_channels` channels.
decoder_input = Input(shape=(latent_dim, ), name='decoder_input')

# Expand the latent vector to the size of the encoder's last feature map, then
# restore that map's (rows, cols, filters) shape.
feature_map_shape = (conv_shape[1], conv_shape[2], conv_shape[3])
feature_map_units = feature_map_shape[0] * feature_map_shape[1] * feature_map_shape[2]
x = Dense(feature_map_units, activation='relu')(decoder_input)
x = Reshape(feature_map_shape)(x)

# The stride-2 transposed convolution undoes the encoder's single stride-2 step;
# the final sigmoid layer keeps every output value in [0, 1].
x = Conv2DTranspose(32, 3, strides=(2, 2), padding='same', activation='relu')(x)
x = Conv2DTranspose(num_channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)

decoder = Model(decoder_input, x, name='decoder')
decoder.summary()

# Reconstruction produced from the sampled latent vector of the encoder.
z_decoded = decoder(z)

In [None]:
class CustomLayer(keras.layers.Layer):
    """Pass-through layer that attaches the VAE training loss to the model.

    Call it on ``[input_img, z_decoded, z_mu, z_sigma]``. The layer returns
    ``z_decoded`` unchanged and registers, via ``add_loss``, the sum of a
    reconstruction term (binary cross-entropy) and a weighted KL term.

    The legacy two-element call ``[input_img, z_decoded]`` is still accepted; in that
    case the KL term falls back to the module-level ``z_mu`` / ``z_sigma`` tensors.
    Passing them explicitly is preferred because the loss then depends only on the
    layer's own inputs rather than on notebook globals.
    """

    def vae_loss(self, x, z_decoded, mu=None, log_var=None):
        """Return the scalar VAE loss for a batch.

        x, z_decoded: original and reconstructed image tensors.
        mu, log_var: latent mean and log-variance tensors; when omitted, the
            module-level ``z_mu`` and ``z_sigma`` are used.
        """
        if mu is None:
            mu = z_mu
        if log_var is None:
            log_var = z_sigma

        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)

        # Both tensors are flattened to 1-D, so this is a single cross-entropy value
        # averaged over every pixel of every sample in the batch.
        recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)

        # KL divergence between N(mu, exp(log_var)) and N(0, 1), averaged over the
        # latent dimensions and scaled by 5e-4 (sign folded into the factor).
        kl_loss = -5e-4 * K.mean(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
        return K.mean(recon_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        if len(inputs) >= 4:
            mu, log_var = inputs[2], inputs[3]
        else:
            mu, log_var = None, None  # legacy call: vae_loss falls back to the globals
        loss = self.vae_loss(x, z_decoded, mu, log_var)
        self.add_loss(loss, inputs=inputs)
        # The output is only the reconstruction; the loss travels via add_loss.
        return z_decoded


y = CustomLayer()([input_img, z_decoded, z_mu, z_sigma])

In [None]:
# End-to-end VAE: image in, reconstruction out.
vae = Model(input_img, y, name='vae')

# No compile-time loss is given: the objective is registered inside the model by the
# custom layer's add_loss call.
vae.compile(loss=None, optimizer='adam')
vae.summary()

# Targets are None for the same reason; 20% of x_train is held out for validation.
fit_options = dict(epochs=200, batch_size=8184, validation_split=0.2)
history = vae.fit(x_train, None, **fit_options)

In [None]:
# Load two further image stacks to pass through the trained encoder.
# Note: this rebinds `cellData`, which earlier in the notebook held the stack that was
# split into training and test data; re-running the split cell after this one would
# therefore split a different array.
cellData = np.load("/content/drive/MyDrive/cellData_HTAN9_32_scene3.npy")
cellData2 = np.load("/content/drive/MyDrive/cellData_HTAN9_93_scene2.npy")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("dark")

# The encoder returns three arrays per call (latent mean, latent log-variance head,
# sampled z); only the latent means are kept for the plots below.
z_mean, _, _ = encoder.predict(trainData)
z_mean_2, _, _ = encoder.predict(cellData)
z_mean_3, _, _ = encoder.predict(cellData2)


In [None]:
# Fixed axis limits for putting the latent plots on a common frame; they are only
# referenced by the commented-out xlim/ylim calls further down.
xMin, xMax = -0.06, 0.015
yMin, yMax = -0.150, 0.03

# Centroid of the encoded means for each dataset:
# m1 = training data, m2 = test data 1, m3 = test data 2 (different time point).
m1, m2, m3 = [np.mean(latent_means, axis=0) for latent_means in (z_mean, z_mean_2, z_mean_3)]

# Scatter of the training-set latent means with the centroid marked in red.
ax = sns.scatterplot(x=z_mean[:, 0], y=z_mean[:, 1], size=3, alpha=0.2)
ax.plot([m1[0]], [m1[1]], 'o', color="Red")

label_style = dict(
    xytext=(0.7, 0.65),                 # label position as a fraction of the figure
    textcoords='figure fraction',
    arrowprops=dict(facecolor='black', shrink=0.05),
    horizontalalignment='right',
    verticalalignment='top',
)
ax.annotate(
    "(%0.3f, %0.3f)" % (m1[0], m1[1]),
    xy=(m1[0], m1[1]),                  # arrow tip, in data coordinates
    **label_style,
)

#plt.xlim([xMin,xMax])
#plt.ylim([yMin,yMax])
ax.set_xlabel("z [0]")
ax.set_ylabel("z [1]")

#plt.show()
#plt.savefig("/content/drive/MyDrive/Results/trained_figure.svg", dpi=300)

In [None]:
# Scatter of the latent means in `z_mean_2`, with their centroid `m2` marked in red
# and labelled with its coordinates.
ax = sns.scatterplot(x=z_mean_2[:, 0], y=z_mean_2[:, 1], size=3, alpha=0.2)

centroid = (m2[0], m2[1])
ax.plot([centroid[0]], [centroid[1]], 'o', color="Red")
ax.annotate(
    "(%0.3f, %0.3f)" % centroid,
    xy=centroid,                        # arrow tip, in data coordinates (z[0], z[1])
    xytext=(0.7, 0.65),                 # label position as a fraction of the figure
    textcoords='figure fraction',
    arrowprops=dict(facecolor='black', shrink=0.05),
    horizontalalignment='right',
    verticalalignment='top',
)

#plt.xlim([xMin,xMax])
#plt.ylim([yMin,yMax])
ax.set_xlabel("z [0]")
ax.set_ylabel("z [1]")

#plt.show()
#plt.savefig("/content/drive/MyDrive/Results/timepoint1_figure.svg", dpi=300)

In [None]:
# Scatter of the latent means in `z_mean_3`, with their centroid `m3` marked in red
# and labelled with its coordinates.
ax = sns.scatterplot(x=z_mean_3[:, 0], y=z_mean_3[:, 1], size=3, alpha=0.2)

centre_x, centre_y = m3[0], m3[1]
ax.plot([centre_x], [centre_y], 'o', color="Red")

arrow = dict(facecolor='black', shrink=0.05)
ax.annotate(
    "(%0.3f, %0.3f)" % (centre_x, centre_y),
    xy=(centre_x, centre_y),            # arrow tip, in data coordinates (z[0], z[1])
    xytext=(0.7, 0.65),                 # label position as a fraction of the figure
    textcoords='figure fraction',
    arrowprops=arrow,
    horizontalalignment='right',
    verticalalignment='top',
)

#plt.xlim([xMin,xMax])
#plt.ylim([yMin,yMax])
ax.set_xlabel("z [0]")
ax.set_ylabel("z [1]")

#plt.show()
#plt.savefig("/content/drive/MyDrive/Results/timepoint2_figure.svg", dpi=300)

In [None]:
# Latent-space centroids, in order: training data, test data 1, test data 2 (different time).
m1, m2, m3 = (
    np.mean(latent_means, axis=0)
    for latent_means in (z_mean, z_mean_2, z_mean_3)
)

print(m1, m2, m3)

In [None]:
# Decode an n x n grid of latent points and tile channel 1 of each decoded image
# into one large picture.
n = 20

grid_x = np.linspace(-2, 3, n)
grid_y = np.linspace(-2, 2, n)[::-1]  # reversed so larger z[1] values appear at the top

# One latent point per tile in row-major order: tile (i, j) uses (grid_x[j], grid_y[i]).
mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
z_grid = np.column_stack([mesh_x.ravel(), mesh_y.ravel()])

# Decode the whole grid in a single batched call instead of n*n one-sample calls.
x_decoded = decoder.predict(z_grid)

# Keep channel 1 of every decoded image, indexed as (tile row, tile col, pixel row, pixel col).
tiles = x_decoded.reshape(n, n, img_width, img_height, num_channels)[..., 1]

# Interleave tile and pixel axes so that tile (i, j) fills
# figure[i*img_width:(i+1)*img_width, j*img_height:(j+1)*img_height].
figure = tiles.transpose(0, 2, 1, 3).reshape(img_width * n, img_height * n)

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='gray')
plt.show()

In [None]:
import numpy as np

def uneven_kl_divergence(pk, qk, rng=None):
    """Return ``sum(p * log(p / q))`` for two 1-D arrays of possibly different length.

    Both inputs are shifted by the same offset so that the smallest value across the
    two arrays becomes 1e-5 (keeping every entry strictly positive for the logarithm).
    The longer array is then randomly resampled, with replacement, down to the length
    of the shorter one before the sum is taken.

    Parameters
    ----------
    pk, qk : array-like, 1-D
        Input values. They are copied, so the caller's arrays are never modified.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness for the resampling step. Defaults to NumPy's global
        random state, so results vary between runs unless it is seeded.

    Returns
    -------
    float
        The summed ``p * log(p / q)``.

    Raises
    ------
    ValueError
        If either input is empty.

    Notes
    -----
    NOTE(review): the inputs are used as given, not normalised to sum to 1 or binned
    into histograms, so the value is not a KL divergence between probability
    distributions in the strict sense — confirm this is the intended statistic.
    """
    # Work on float copies: avoids mutating the caller's data and integer-cast errors.
    p = np.array(pk, dtype=float)
    q = np.array(qk, dtype=float)
    if p.size == 0 or q.size == 0:
        raise ValueError("uneven_kl_divergence requires non-empty inputs")

    # Compute the offset once, from the unshifted data, and apply the same offset to
    # both arrays. Deriving the second offset after shifting the first array would
    # move the two arrays by different amounts.
    shift = min(p.min(), q.min()) - 1e-5
    p -= shift
    q -= shift

    choose = np.random.choice if rng is None else rng.choice
    if len(p) > len(q):
        p = choose(p, len(q))
    elif len(q) > len(p):
        q = choose(q, len(p))

    return np.sum(p * np.log(p / q))

In [None]:
# Pairwise divergences between the three encoded datasets, computed separately for
# each latent axis. Pair order: (train, set 2), (train, set 3), (set 2, set 3).
latent_sets = (z_mean, z_mean_2, z_mean_3)
dataset_pairs = ((0, 1), (0, 2), (1, 2))

# Copies are passed so the encoded arrays are never altered by the divergence function.
sx, sx2, sx3 = [
    uneven_kl_divergence(latent_sets[first][:, 0].copy(), latent_sets[second][:, 0].copy())
    for first, second in dataset_pairs
]

#print(z_mean[1], z_mean_2[1], z_mean_3[1])

sy, sy2, sy3 = [
    uneven_kl_divergence(latent_sets[first][:, 1].copy(), latent_sets[second][:, 1].copy())
    for first, second in dataset_pairs
]

print(sx, sx2, sx3)
print(sy, sy2, sy3)

In [None]:
# Show channel 18 of one test sample (left) next to its VAE reconstruction (right).
# Slicing with 1:2 keeps the batch axis, so the sample already has the
# (1, height, width, channels) shape the model expects for any image size.
sample_batch = x_test[1:2]

plt.subplot(1, 2, 1)
plt.imshow(sample_batch[0][:, :, 18])

plt.subplot(1, 2, 2)
result = vae.predict(sample_batch)
plt.imshow(result[0, :, :, 18])

#plt.show()
#plt.savefig("/content/drive/MyDrive/Results/RebuiltImage.svg", dpi=300)

In [None]:
# summarize history for loss
# Training and validation loss per epoch, drawn in that order so they match the legend.
for curve_key in ('loss', 'val_loss'):
    plt.plot(history.history[curve_key])

plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'test'], loc='upper left')
#plt.show()
#plt.savefig("/content/drive/MyDrive/Results/ModelLoss.svg", dpi=300)