<a href="https://colab.research.google.com/github/udayameister/Connectome/blob/main/Fake_matrix_using_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers


In [8]:
# Read the real brain connectivity data (no header row) and regroup the values
# into a stack of 90x90 matrices. Assumes each matrix's 8100 values are stored
# contiguously in row-major order, e.g. one CSV row per matrix -- TODO confirm.
real_data = pd.read_csv('real_data.csv', header=None).to_numpy().reshape(-1, 90, 90)


In [9]:
# Define generator model
def make_generator_model(noise_dim=100, matrix_size=90):
    """Build the generator: a noise vector -> one square connectivity matrix.

    Args:
        noise_dim: Length of the input noise vector (default 100).
        matrix_size: Side length of the generated square matrix (default 90).

    Returns:
        An uncompiled ``tf.keras.Sequential`` mapping ``(batch, noise_dim)``
        to ``(batch, matrix_size, matrix_size)``.
    """
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_shape=(noise_dim,)))
    model.add(layers.LeakyReLU())
    # Sigmoid keeps every generated value in [0, 1], the same range the real
    # matrices are min-max scaled to before training. A tanh output in [-1, 1]
    # would let the discriminator separate fakes by sign alone, and the [0, 1]
    # inverse scaling applied after generation would not undo it.
    model.add(layers.Dense(matrix_size * matrix_size, activation='sigmoid'))
    model.add(layers.Reshape((matrix_size, matrix_size)))
    return model

In [10]:
# Define discriminator model
def make_discriminator_model():
    """Build the discriminator: a (90, 90) matrix -> a single raw logit.

    The last Dense layer has no activation, so its output is an unnormalized
    score. A loss paired with it must therefore be constructed with
    ``from_logits=True``.
    """
    return tf.keras.Sequential([
        layers.Flatten(input_shape=(90, 90)),
        layers.Dense(256),
        layers.LeakyReLU(),
        layers.Dense(1),
    ])


In [11]:
# Define GAN model
def make_gan_model(generator, discriminator):
    """Chain ``generator`` and ``discriminator`` into one noise -> logit model.

    Side effect: sets ``discriminator.trainable = False`` on the instance that
    is passed in, not on a copy. In Keras this also empties
    ``discriminator.trainable_variables``. Any custom loop that later updates
    the discriminator directly has to set ``trainable`` back to ``True``.
    """
    discriminator.trainable = False
    stacked = [generator, discriminator]
    return tf.keras.Sequential(stacked)


In [12]:
# Instantiate both networks (generator first, then discriminator)
generator, discriminator = make_generator_model(), make_discriminator_model()


In [13]:
# Stack both networks into one model (note: this freezes `discriminator`)
gan = make_gan_model(generator=generator, discriminator=discriminator)


In [14]:
# The discriminator emits raw logits (its last Dense layer has no activation),
# so the loss applies the sigmoid itself.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# One Adam optimizer per network. Nothing is compiled here; train_step applies
# the gradients itself.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [15]:
@tf.function
def train_step(real_data, batch_size):
    """Run one combined generator/discriminator update on a batch of real data.

    Both networks share one forward pass over the concatenation of the real
    batch and a freshly generated fake batch. Each network is then updated
    with its own optimizer.

    Args:
        real_data: Batch of real matrices. It is cast to the generator's
            output dtype and reshaped to the generator's per-sample output
            shape. Both ``(n, 90, 90)`` and flattened ``(n, 8100)`` batches
            are therefore accepted, in any float dtype. ``n`` may differ from
            ``batch_size``, e.g. for a dataset's final, partial batch.
        batch_size: Number of fake matrices to generate for this step.

    Returns:
        ``(gen_loss, disc_loss)`` as scalar tensors.
    """
    # make_gan_model() sets discriminator.trainable = False. In Keras that
    # leaves discriminator.trainable_variables empty, so the discriminator
    # would never be updated, or apply_gradients would reject the empty list.
    # Under tf.function this Python line runs while tracing, i.e. before
    # trainable_variables is read below.
    discriminator.trainable = True

    # Generate random noise as input to the generator
    noise = tf.random.normal([batch_size, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate fake connectivity matrices using the generator
        fake_data = generator(noise, training=True)

        # Bring the real batch to the generator's dtype and per-sample shape
        # so the two can be concatenated along the batch axis.
        real = tf.cast(real_data, fake_data.dtype)
        real = tf.reshape(real, [-1, *fake_data.shape[1:]])
        num_real = tf.shape(real)[0]

        combined_data = tf.concat([real, fake_data], axis=0)
        predictions = discriminator(combined_data, training=True)
        # The first num_real rows score real matrices; the rest score fakes.
        real_logits = predictions[:num_real]
        fake_logits = predictions[num_real:]

        # Discriminator objective: real -> 1, fake -> 0
        labels = tf.concat(
            [tf.ones_like(real_logits), tf.zeros_like(fake_logits)], axis=0
        )
        disc_loss = loss_fn(labels, predictions)

        # Generator objective: have its fakes classified as real. This must
        # use the fake rows; the real rows do not depend on the generator.
        gen_loss = loss_fn(tf.ones_like(fake_logits), fake_logits)

    # Calculate gradients
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Apply gradients
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

In [16]:
# Training loop
def train(dataset, epochs, batch_size):
    """Make ``epochs`` full passes over ``dataset``, calling train_step per batch.

    ``batch_size`` is forwarded unchanged to every train_step call.
    """
    for _ in range(epochs):
        for real_batch in dataset:
            train_step(real_batch, batch_size)


In [17]:
# Preprocess the real data
real_data = (real_data - np.min(real_data)) / (np.max(real_data) - np.min(real_data))


In [20]:
# Reshape real data for training
real_data = real_data.reshape(real_data.shape[0], -1)

In [23]:
# Training parameters
epochs = 1000
batch_size = 32

In [24]:
dataset = tf.data.Dataset.from_tensor_slices(real_data).batch(batch_size)

In [25]:
num_fake_matrices = 10
noise = tf.random.normal([num_fake_matrices, 100])
fake_matrices = generator(noise, training=False)

In [26]:
fake_matrices = (fake_matrices * (np.max(real_data) - np.min(real_data))) + np.min(real_data)
fake_matrices = fake_matrices.numpy().reshape(num_fake_matrices, 90, 90)

In [29]:
# Write each generated matrix to its own headerless, index-free CSV file,
# numbered from 1: fake_matrix_1.csv, fake_matrix_2.csv, ...
for i, matrix in enumerate(fake_matrices):
    frame = pd.DataFrame(matrix)
    frame.to_csv(f'fake_matrix_{i+1}.csv', index=False, header=False)