In [1]:
import tensorflow as tf

ModuleNotFoundError: No module named 'tensorflow'

In [3]:
# Shape of a single input example; read by the CVAE encoders' input layers.
input_shape = (
    512,
    64,
    1,
)

# Convolution hyperparameters read by CVAE (presumably one entry per
# convolution layer -- confirm against the model definition). The last stride
# entry is itself a pair, (2, 1), rather than a single integer.
conv_filters = (512, 256, 128, 64, 32)
conv_kernels = (3, 3, 3, 3, 3)
conv_strides = (2, 2, 2, 2, (2, 1))

# Both names refer to the same latent-vector size.
latent_space_dim = vector_dimension = 64

In [None]:
class CVAE(tf.keras.Model):
  """Convolutional variational autoencoder with two parallel branches.

  The model holds two independent encoders (`e1`, `e2`) and two independent
  decoders (`d1`, `d2`). Each encoder maps an input of shape `input_shape` to
  a vector of size `2 * latent_dim`, which is split into a mean and a
  log-variance. Each decoder maps a latent vector of size `latent_dim` back to
  a tensor of shape `input_shape`; `decode` sums the logits of both decoders.

  The layer stacks are built from the module-level `input_shape`,
  `conv_filters`, `conv_kernels` and `conv_strides`, using one
  (filters, kernel, stride) triple per convolution layer.
  """

  def __init__(self, latent_dim):
    """Builds both encoder and both decoder networks.

    Args:
      latent_dim: Size of each latent vector (z1 and z2).

    Raises:
      ValueError: If the convolution hyperparameter tuples are empty or do not
        all have the same length.
    """
    super().__init__()
    self.latent_dim = latent_dim

    if not conv_filters or not (
        len(conv_filters) == len(conv_kernels) == len(conv_strides)):
      raise ValueError(
          'conv_filters, conv_kernels and conv_strides must be non-empty '
          'and have the same length')

    bottleneck_shape = self._bottleneck_shape()

    self.e1 = self._build_encoder()
    self.e2 = self._build_encoder()
    self.d1 = self._build_decoder(bottleneck_shape)
    self.d2 = self._build_decoder(bottleneck_shape)
    # NOTE: no `self.encode` / `self.decoder` attributes are created here.
    # `tf.keras.layers.concatenate` operates on tensors, not on Sequential
    # models, and an instance attribute named `encode` would shadow the
    # `encode` method. The branches are combined in `encode` / `decode`.

  @staticmethod
  def _bottleneck_shape():
    """Returns the (height, width, channels) left after the encoder convs.

    With 'same' padding every strided convolution maps a spatial size `n` to
    `ceil(n / stride)`; the channel count is that of the last conv layer.
    """
    height, width = input_shape[0], input_shape[1]
    for stride in conv_strides:
      stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
      height = -(-height // stride_h)  # ceiling division
      width = -(-width // stride_w)
    return (height, width, conv_filters[-1])

  def _build_encoder(self):
    """Builds one encoder: strided conv stack -> flatten -> mean/logvar."""
    stack = [tf.keras.layers.InputLayer(input_shape=input_shape)]
    for filters, kernel_size, strides in zip(
        conv_filters, conv_kernels, conv_strides):
      stack.append(
          tf.keras.layers.Conv2D(
              filters=filters, kernel_size=kernel_size, strides=strides,
              # 'same' padding keeps the downsampling exactly invertible by
              # the transposed convolutions of the decoder.
              padding='same', activation='relu'))
    stack.append(tf.keras.layers.Flatten())
    # No activation: the output holds the mean and the log-variance.
    stack.append(tf.keras.layers.Dense(self.latent_dim + self.latent_dim))
    return tf.keras.Sequential(stack)

  def _build_decoder(self, bottleneck_shape):
    """Builds one decoder that mirrors the encoder back to `input_shape`.

    The output matches `input_shape` exactly when every spatial dimension is
    divisible by the product of the strides along that axis.
    """
    height, width, channels = bottleneck_shape
    stack = [
        tf.keras.layers.InputLayer(input_shape=(self.latent_dim,)),
        tf.keras.layers.Dense(
            units=height * width * channels, activation=tf.nn.relu),
        # Conv2DTranspose needs a 4-D input, so restore the spatial layout.
        tf.keras.layers.Reshape(target_shape=bottleneck_shape),
    ]
    # Undo the encoder layers from the deepest one up to the second one.
    for idx in reversed(range(1, len(conv_filters))):
      stack.append(
          tf.keras.layers.Conv2DTranspose(
              filters=conv_filters[idx], kernel_size=conv_kernels[idx],
              strides=conv_strides[idx], padding='same', activation='relu'))
    # Output layer: undoes the first encoder stride and restores the input
    # channel count. No activation -- the decoder returns logits.
    stack.append(
        tf.keras.layers.Conv2DTranspose(
            filters=input_shape[-1], kernel_size=conv_kernels[0],
            strides=conv_strides[0], padding='same'))
    return tf.keras.Sequential(stack)

  @tf.function
  def sample_z1(self, eps=None):
    """Decodes `eps` (or 100 standard-normal draws) through decoder 1."""
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode_z1(eps, apply_sigmoid=True)

  @tf.function
  def sample_z2(self, eps=None):
    """Decodes `eps` (or 100 standard-normal draws) through decoder 2."""
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode_z2(eps, apply_sigmoid=True)

  @tf.function
  def sample(self, eps=None):
    """Returns a pair of samples, one from each decoder.

    When `eps` is given it is forwarded to both decoders; when it is None each
    decoder draws its own standard-normal latent vectors.
    """
    return self.sample_z1(eps), self.sample_z2(eps)

  def encode_z1(self, x):
    """Returns (mean, logvar) produced by encoder 1 for input `x`."""
    mean, logvar = tf.split(self.e1(x), num_or_size_splits=2, axis=1)
    return mean, logvar

  def encode_z2(self, x):
    """Returns (mean, logvar) produced by encoder 2 for input `x`."""
    mean, logvar = tf.split(self.e2(x), num_or_size_splits=2, axis=1)
    return mean, logvar

  def encode(self, x):
    """Returns ((mean1, logvar1), (mean2, logvar2)) for input `x`."""
    return self.encode_z1(x), self.encode_z2(x)

  def reparameterize(self, mean, logvar):
    """Draws z = mean + eps * exp(logvar / 2) with eps ~ N(0, I)."""
    # tf.shape gives the runtime shape, so this also works when the static
    # batch dimension of `mean` is unknown (None) inside a traced graph.
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * .5) + mean

  def decode_z1(self, z1, apply_sigmoid=False):
    """Decodes `z1` with decoder 1; returns logits, or sigmoid(logits)."""
    logits = self.d1(z1)
    if apply_sigmoid:
      return tf.sigmoid(logits)
    return logits

  def decode_z2(self, z2, apply_sigmoid=False):
    """Decodes `z2` with decoder 2; returns logits, or sigmoid(logits)."""
    logits = self.d2(z2)
    if apply_sigmoid:
      return tf.sigmoid(logits)
    return logits

  def decode(self, z1, z2, apply_sigmoid=False):
    """Decodes both latent vectors and sums the two decoders' logits.

    Args:
      z1: Latent vectors for decoder 1.
      z2: Latent vectors for decoder 2.
      apply_sigmoid: If True, return sigmoid of the summed logits instead of
        the raw sum.
    """
    logits = self.d1(z1) + self.d2(z2)
    if apply_sigmoid:
      return tf.sigmoid(logits)
    return logits