<a href="https://colab.research.google.com/github/spour/VAEs_DNAseqs/blob/main/DNA_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow.keras as keras
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Add, BatchNormalization, Activation, ZeroPadding2D, LeakyReLU, UpSampling1D, Conv2D, Layer, Conv1D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

class Normalize(Layer):
    """Instance normalization with a learned per-channel scale and offset.

    For every sample and every channel the activations are standardized over
    the spatial axes (all axes except the leading batch axis and the trailing
    channel axis) and then affinely transformed:

        y = scale * (x - mean) / sqrt(variance + epsilon) + offset

    Args:
        epsilon: Small constant added to the variance before the square root
            so that a zero variance does not cause a division by zero.
        **kwargs: Forwarded unchanged to the Keras ``Layer`` constructor.
    """

    def __init__(self, epsilon=1e-8, **kwargs):
        super(Normalize, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create one trainable scale and one trainable offset per channel."""
        channel_shape = input_shape[-1:]
        self.scale = self.add_weight(name="scale", shape=channel_shape,
                                     initializer='ones', trainable=True)
        self.offset = self.add_weight(name="offset", shape=channel_shape,
                                      initializer='zeros', trainable=True)
        super(Normalize, self).build(input_shape)

    def call(self, inputs, training=None):
        """Normalize ``inputs`` per sample and per channel.

        ``training`` is accepted for Keras API compatibility; the computation
        is the same in training and inference.
        """
        # Reduce over the spatial axes only. The previous hard-coded
        # axes=[1, 2] is what this yields for 4-D (batch, h, w, channels)
        # tensors, but on the 3-D (batch, steps, channels) tensors produced
        # by Conv1D it also averaged across channels, which is not instance
        # normalization.
        spatial_axes = list(range(1, len(inputs.shape) - 1))
        mean, variance = tf.nn.moments(inputs, axes=spatial_axes, keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return self.scale * normalized + self.offset

    def get_config(self):
        """Return the layer config, including ``epsilon``, for serialization."""
        config = super(Normalize, self).get_config()
        config.update({"epsilon": self.epsilon})
        return config

class ThresholdedData(Layer):
    """Keras layer that flags the maxima along axis 2 with 1 and the rest with 0.

    Every entry equal to the maximum of its axis-2 slice becomes 1; all other
    entries become 0. Tied maxima are all flagged, so a slice can contain more
    than one 1. The layer creates no weights.
    """

    def __init__(self, **kwargs):
        # No layer-specific options; keyword arguments go straight to Layer.
        super().__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """Return the 0/1 mask of axis-2 maxima of ``inputs``."""
        peak = tf.math.reduce_max(inputs, axis=2, keepdims=True)
        # tf.math.equal already produces a boolean tensor, so no cast is needed.
        is_peak = tf.math.equal(inputs, peak)
        return tf.where(is_peak, 1, 0)

class Generator:
    """ResNet-style 1D convolutional generator.

    The wrapped Keras model is a residual mapper: a convolutional encoder,
    nine residual blocks and an upsampling decoder produce a tanh-bounded
    correction that is added to the input, so the output has the same shape
    as the input.

    Args:
        input_shape: Shape of one sample without the batch axis, e.g.
            ``(length, channels)``. The two stride-2 convolutions are undone
            by two 2x upsamplings, so ``length`` has to be a multiple of 4
            for the final residual addition to line up.
        output_shape: Stored on the instance only; it is not consulted when
            the model is built.

    Attributes:
        model: The Keras ``Model`` returned by ``build_model``.
    """

    def __init__(self, input_shape, output_shape=None):
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.model = self.build_model()

    def resnet_blocks(self, input, filters, kernel_size=7, strides=1, use_dropout=False):
        """Apply one residual block: conv -> norm -> relu [-> dropout] -> conv -> norm, plus skip.

        Args:
            input: Tensor entering the block. (The name shadows the builtin
                but is kept so existing keyword callers keep working.)
            filters: Number of filters of both convolutions; it should match
                the channel count of ``input`` so the skip addition lines up.
            kernel_size: Kernel size of both convolutions.
            strides: Stride of both convolutions.
            use_dropout: If true, apply ``Dropout(0.5)`` after the first
                activation.

        Returns:
            The sum of ``input`` and the transformed tensor.
        """
        x = Conv1D(filters=filters,
                   kernel_size=kernel_size,
                   padding='same',
                   strides=strides)(input)
        x = Normalize()(x)
        x = Activation('relu')(x)
        if use_dropout:
            x = Dropout(0.5)(x)
        x = Conv1D(filters=filters,
                   kernel_size=kernel_size,
                   padding='same',
                   strides=strides)(x)
        x = Normalize()(x)
        return Add()([input, x])

    def build_model(self):
        """Build the generator architecture and return it as a Keras ``Model``."""
        # The correction must have as many channels as the input. With a
        # single filter, the one-channel correction was broadcast over every
        # input channel by the final Add, shifting all channels equally, so
        # the per-position argmax of the output always equalled the input's.
        n_channels = self.input_shape[-1]

        inputs = Input(shape=self.input_shape)

        # Encoder. For an input of (batch, L, C): -> (batch, L, 64)
        x = Conv1D(filters=64, kernel_size=7, strides=1, padding='same')(inputs)
        x = Normalize()(x)
        x = Activation('relu')(x)

        # -> (batch, ceil(L / 2), 128)
        x = Conv1D(filters=128, kernel_size=3, strides=2, padding='same')(x)
        x = Normalize()(x)
        x = Activation('relu')(x)

        # -> (batch, ceil(L / 4), 256)
        x = Conv1D(filters=256, kernel_size=3, strides=2, padding='same')(x)
        x = Normalize()(x)
        x = Activation('relu')(x)

        # Residual trunk; shape is unchanged.
        for _ in range(9):
            x = self.resnet_blocks(x, 256, use_dropout=True)

        # Decoder: two 2x upsamplings restore the length when L % 4 == 0.
        x = UpSampling1D(2)(x)
        x = Conv1D(filters=128, kernel_size=3, strides=1, padding='same')(x)
        x = Normalize()(x)
        x = Activation('relu')(x)

        x = UpSampling1D(2)(x)
        x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same')(x)
        x = Normalize()(x)
        x = Activation('relu')(x)

        # Per-channel correction in [-1, 1], added to the input.
        x = Conv1D(filters=n_channels, kernel_size=7, padding='same')(x)
        x = Activation('tanh')(x)
        outputs = Add()([x, inputs])

        return Model(inputs=inputs, outputs=outputs, name='Generator')

    def call(self, x):
        """Run the wrapped model on ``x`` and return its raw, unthresholded output."""
        return self.model(x)


In [None]:
# # Define the size of the data
# data_size = (100,4)

# # Create a generator object
# generator = Generator(input_shape=(100,), output_shape=data_size)

# # Generate new data using the generator
# random_noise = np.random.rand(1, 10, 4)
# generated_data = generator.model.predict(random_noise)
# # Threshold the generated data
# # Get the maximum value of each row
# max_value = np.amax(generated_data, axis=2)
# # broadcast max_value to the same shape as the generated data
# max_value = max_value[..., None]
# # Compare the max_value with the generated_data
# thresholded_data = np.where(generated_data == max_value, 1, 0)


# Build a generator for samples of shape (12, 4).
gen = Generator((12, 4))
# Leftover shape probe: its value is discarded, but it still consumes draws
# from NumPy's global random state.
np.random.rand(10, 4).shape
# gen.model
# Run the untrained generator on a random batch of 5 samples.
gen_preds = gen.model.predict(np.random.rand(5, 12, 4))
# Display the raw predictions next to their hard one-hot version
# (argmax over the last axis, depth 4).
gen_preds, tf.keras.backend.one_hot(tf.argmax(gen_preds, axis=-1), 4)



(array([[[ 4.33992445e-01,  3.14303041e-02,  4.46219742e-01,
          -2.51081139e-02],
         [-6.99107349e-02,  4.36770260e-01,  9.50753689e-04,
           1.23276770e-01],
         [ 5.60496509e-01,  4.25479412e-01,  2.03922480e-01,
          -2.24216878e-01],
         [ 1.48257971e+00,  5.96048892e-01,  1.30546606e+00,
           1.05171609e+00],
         [ 1.66187024e+00,  1.67680573e+00,  1.53668177e+00,
           1.03697729e+00],
         [ 7.13052571e-01,  4.94369417e-01,  1.07618785e+00,
           8.45406234e-01],
         [ 7.65817165e-01,  1.34421384e+00,  7.58324265e-01,
           1.07932210e+00],
         [ 7.10618198e-01,  1.24024630e+00,  8.92472267e-01,
           1.30849218e+00],
         [ 5.39994538e-01,  1.02108389e-01,  4.06426013e-01,
           3.59333634e-01],
         [ 8.32704484e-01,  6.23886883e-02,  7.78951883e-01,
           3.91527116e-01],
         [ 5.74783027e-01,  7.91240752e-01,  4.91049707e-01,
           8.24134529e-01],
         [-1.42338693

In [None]:
def custom_loss(y_true, y_pred):
    """Mean softmax cross-entropy between target sequences and predicted scores.

    Args:
        y_true: Target distribution over the 4 classes on the last axis
            (e.g. a one-hot encoded DNA sequence).
        y_pred: Raw per-class scores with the same shape as ``y_true``. They
            are treated as logits, i.e. a softmax is applied inside the loss.

    Returns:
        A scalar tensor: the cross-entropy averaged over all positions.
    """
    # The scores are deliberately NOT converted to a hard one-hot via
    # argmax first: argmax returns integer indices and has no gradient, so a
    # loss built on it cannot be used to train the model.
    per_position = tf.keras.losses.categorical_crossentropy(
        y_true, y_pred, from_logits=True)
    return tf.reduce_mean(per_position)

  

In [None]:
# define the training data
# Inputs: random values in [0, 1) with shape (batch=5, length=12, channels=4).
x_train = np.random.rand(5, 12, 4)
# Targets: random one-hot sequences of the same shape. A categorical
# cross-entropy target must be a distribution over the 4 classes at every
# position, which uniform random floats are not.
y_train = np.eye(4)[np.random.randint(0, 4, size=(5, 12))]

# compile generator with custom loss function
gen.model.compile(optimizer='adam', loss=custom_loss)

# train generator with custom loss function
gen.model.fit(x_train, y_train)

ValueError: ignored

In [None]:
# Run the generator on a fresh random batch of shape (5, 12, 4) and display the result.
gen.model.predict(np.random.rand(5, 12, 4))



array([[[ 1.13453001e-01,  8.78755569e-01,  6.90156519e-01,
          1.00904191e+00],
        [ 5.97126067e-01,  8.10518920e-01,  1.39585331e-01,
          4.40598190e-01],
        [ 4.20231462e-01,  5.77590287e-01,  5.50183654e-02,
          8.65154266e-02],
        [ 5.08240581e-01,  6.10596776e-01,  2.72514582e-01,
          7.61170089e-01],
        [ 8.18453610e-01,  4.32520777e-01,  5.80110073e-01,
          8.59706700e-01],
        [ 6.13993466e-01,  3.75480950e-01,  8.11111271e-01,
          2.12620482e-01],
        [ 1.19842219e+00,  3.93357813e-01,  9.14933920e-01,
          1.17848849e+00],
        [ 6.24297708e-02,  8.57748508e-01,  4.22832221e-01,
          1.03707957e+00],
        [ 2.42598116e-01, -3.88725728e-01,  1.82259500e-01,
          1.70160532e-01],
        [-8.59919786e-02, -4.78339732e-01,  9.02146697e-02,
          8.34679008e-02],
        [ 2.45593160e-01, -3.22344661e-01,  6.17783546e-01,
         -1.42424658e-01],
        [ 6.00635946e-01,  7.25187123e-01, 

In [None]:
# Run the generator on another fresh random batch of shape (5, 12, 4).
gen.model.predict(np.random.rand(5, 12, 4))

In [None]:
# Draw 10 random values, each either 1 or 0, and display them.
np.random.choice([1, 0], size=10)

array([0, 0, 0, 0, 1, 0, 1, 1, 1, 1])

In [None]:
# Leftover shape probe: its value is discarded, but it still consumes draws
# from NumPy's global random state.
np.random.rand(10, 4).shape
# gen.model
# Predict on a random batch of shape (5, 12, 4) and keep the result.
gen_preds = gen.model.predict(np.random.rand(5, 12, 4))
# Display the raw predictions.
gen_preds



array([[[0.95649457, 1.5792304 , 1.0432594 , 1.0850923 ],
        [1.5438049 , 0.96144783, 1.136067  , 1.7124841 ],
        [1.4994843 , 1.0132003 , 0.77493596, 1.0619184 ],
        [1.3604823 , 1.1891153 , 1.5660594 , 1.5014403 ],
        [0.8869077 , 1.0221043 , 1.6248938 , 0.82657355],
        [1.6517781 , 1.5818973 , 1.0305252 , 1.6367309 ],
        [1.3927432 , 1.1659665 , 0.8729971 , 0.85839224],
        [0.9588752 , 1.6286709 , 0.86501026, 1.0427992 ],
        [1.7081871 , 1.3937707 , 1.0574489 , 0.8925362 ],
        [1.6213205 , 1.4511935 , 1.651908  , 1.285619  ],
        [1.7595004 , 0.8323757 , 1.2690752 , 1.2173231 ],
        [0.9759721 , 1.5377301 , 0.9502256 , 1.7309215 ]],

       [[1.2288176 , 0.9360242 , 1.4289131 , 0.88950294],
        [1.4922442 , 1.3282952 , 1.2709968 , 1.7420428 ],
        [1.7250664 , 0.9172133 , 1.3491843 , 1.6712048 ],
        [1.183441  , 1.5954129 , 0.9576025 , 1.5941265 ],
        [1.5062143 , 1.0914546 , 1.5714425 , 1.2006596 ],
        [0.9

In [None]:
# Display the current value of gen_preds.
gen_preds

array([[[1.3699703 , 1.6136183 , 1.0059911 , 1.0015804 ],
        [1.5101634 , 0.9045139 , 1.4829825 , 1.4254487 ],
        [1.271283  , 1.5862303 , 1.2953815 , 1.519671  ],
        [1.4197628 , 1.1679344 , 1.6618164 , 1.1136235 ],
        [0.8086056 , 1.2764151 , 0.8432677 , 1.4145457 ],
        [1.2721388 , 1.7566196 , 1.4759827 , 1.5125242 ],
        [1.5442383 , 0.7729102 , 1.034459  , 1.2865183 ],
        [1.0363944 , 0.91669935, 1.7156918 , 1.169918  ],
        [1.4007537 , 0.80221903, 0.78630257, 1.0212091 ],
        [1.5649639 , 1.5522517 , 1.1897322 , 1.3408498 ],
        [1.5464742 , 0.88783324, 1.6012317 , 1.5146523 ],
        [0.7846871 , 1.0504107 , 1.1544323 , 1.5562268 ]],

       [[1.3304642 , 1.3249049 , 1.5507174 , 1.6953934 ],
        [1.3687134 , 1.4138108 , 1.5976247 , 1.7156005 ],
        [1.7102622 , 1.1496952 , 1.576879  , 0.76611906],
        [1.7044578 , 1.7199925 , 1.3303728 , 1.3985105 ],
        [1.0948073 , 0.93318766, 1.1050471 , 0.86061126],
        [1.4

In [None]:
# Predict on a fresh random batch, then hard-threshold the result:
# along the last axis (axis 2) every entry equal to the maximum becomes 1
# and everything else becomes 0.
gen_preds = gen.model.predict(np.random.rand(5, 12, 4))
# keepdims=True keeps the reduced axis as size 1, so the maxima broadcast
# directly against gen_preds.
max_value = np.amax(gen_preds, axis=2, keepdims=True)
thresholded_data = np.where(gen_preds == max_value, 1, 0)

In [None]:
# Display the 0/1 thresholded predictions.
thresholded_data

array([[[1, 0, 0, 0],
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0]],

       [[1, 0, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 1]],

       [[0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0]],

       [[0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    