In [11]:
import anndata
import scanpy as sc
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import pandas as pd
import umap
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from keras.callbacks import EarlyStopping

adata = sc.read('/tmp/work/RCproject_code/sce_export.h5ad')



In [12]:
# generate numerical values for each batch category
from sklearn.preprocessing import LabelEncoder
# set up categories variable
categories = adata.obs['batch']
# Create a LabelEncoder instance
label_encoder = LabelEncoder()
# Fit and transform the categories to integers
numerical_categories = label_encoder.fit_transform(categories)
numerical_categories.shape

(450,)

In [13]:
#normalizaiton
gene_expression_data = adata.layers['logcounts']
from sklearn.preprocessing import MinMaxScaler
# Min-max normalization
scaler = MinMaxScaler()
gene_expression_data = scaler.fit_transform(gene_expression_data)
number_samples = adata.shape[0]
number_genes = adata.shape[1]
input_dim = number_genes

In [14]:
# input_shape = gene_expression_data[0, :].shape  # This should be (12165,) if that's the feature count
# encoding_dim = 64  # This is the size of the encoded representation

input_shape = (12165,)  # Set to match your actual data
encoding_dim = 64  # Output dimension

In [32]:
from keras import layers, models

# Define the input shape
input_shape = (12165,)  # Number of genes

# Define the encoder function
def build_encoder():
    model = models.Sequential()
    model.add(layers.Input(shape=(12165,)))  # Set to match the input shape of 12165
    model.add(layers.Dense(256, activation='linear'))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(128, activation='linear'))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(64, activation='linear'))  # Encoded representation
    return model

# Define the decoder function (adjusted as discussed)
def build_decoder():    
    model = models.Sequential()
    model.add(layers.Input(shape=(64,)))  # Input shape should match the output of the encoder
    model.add(layers.Dense(128, activation='linear'))  
    model.add(layers.LeakyReLU())    
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(256, activation='linear'))  
    model.add(layers.LeakyReLU())    
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(12165, activation='sigmoid'))  # Output layer should match the input shape of the original data
    return model

In [25]:
# Define the domain classifier function
def build_domain_classifier(input_shape, num_domains):
    model = models.Sequential()    
    model.add(layers.Input(shape=input_shape))    
    model.add(layers.Dense(128, activation=None))
    model.add(layers.LeakyReLU()) 
    model.add(layers.BatchNormalization())    
    model.add(layers.Dense(64, activation=None))
    model.add(layers.LeakyReLU()) 
    model.add(layers.BatchNormalization())    
    model.add(layers.Dense(32, activation=None))
    model.add(layers.LeakyReLU())
    model.add(layers.BatchNormalization())    
    model.add(layers.Dense(num_domains, activation='softmax'))  # num_domains is the number of classes
    return model

In [33]:
# establishes a gradient reversal class

class GradientReversalLayer(tf.keras.layers.Layer):
    def __init__(self, lambda_value=1.0, **kwargs):
        self.lambda_value = lambda_value
        super(GradientReversalLayer, self).__init__(**kwargs)

    def call(self, inputs):
        return inputs

    def get_config(self):
        config = super().get_config()
        config.update({"lambda_value": self.lambda_value})
        return config

    def compute_gradient(self, inputs):
        return -self.lambda_value * inputs

In [57]:
# Define the encoder, decoder, and discriminator (assuming these functions are defined as in previous examples)
encoder = build_encoder()  # Assuming build_encoder() is defined
decoder = build_decoder()  # Assuming build_decoder() is defined
discriminator = build_domain_classifier((64,), 9)  # Assuming build_domain_classifier() is defined

# Optimizers for each model
encoder_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
decoder_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Assuming you have 1000 training samples as an example
num_samples = 1000  # Replace this with your actual number of samples
batch_size = 32

# Calculate the number of steps per epoch
num_steps_per_epoch = num_samples // batch_size

# Compile the discriminator
discriminator.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Create a combined model for the encoder and decoder if training together
combined_model = models.Sequential([encoder, decoder])

# Compile combined model
combined_model.compile(optimizer='adam', loss='mean_squared_error')  # Use an appropriate loss for your case

# Training loop
num_epochs = 100

In [None]:
# this wasn't actually work with real data, need to replace.... 

In [58]:
for epoch in range(num_epochs):
    for step in range(num_steps_per_epoch):
        # Get a batch of data
        X_batch = np.random.rand(batch_size, 12165)  # Replace with your actual data
        y_true = np.random.randint(0, 9, size=(batch_size,))  # Replace with actual domain labels
        y_true = tf.keras.utils.to_categorical(y_true, num_classes=9)  # Convert labels to categorical

        with tf.GradientTape(persistent=True) as tape:
            # Encode
            encoded_output = encoder(X_batch, training=True)
            # Decode
            decoded_output = decoder(encoded_output, training=True)

            # Domain output
            domain_output = discriminator(encoded_output, training=True)

            # Calculate losses
            reconstruction_loss = tf.reduce_mean(tf.square(X_batch - decoded_output))  # Scalar
            domain_loss = tf.keras.losses.categorical_crossentropy(y_true, domain_output)  # Shape: (batch_size, num_classes)
            domain_loss = tf.reduce_mean(domain_loss)  # Average over the batch

            # Total loss
            total_loss = reconstruction_loss + domain_loss

        # Calculate gradients and update weights
        encoder_gradients = tape.gradient(total_loss, encoder.trainable_variables)
        decoder_gradients = tape.gradient(total_loss, decoder.trainable_variables)
        discriminator_gradients = tape.gradient(domain_loss, discriminator.trainable_variables)

        # Apply gradients
        encoder_optimizer.apply_gradients(zip(encoder_gradients, encoder.trainable_variables))
        decoder_optimizer.apply_gradients(zip(decoder_gradients, decoder.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

    # Print out scalar loss values
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss.numpy()}, Reconstruction Loss: {reconstruction_loss.numpy()}, Domain Loss: {domain_loss.numpy()}")

Epoch 1/100, Loss: 2.405923366546631, Reconstruction Loss: 0.09312045574188232, Domain Loss: 2.312802791595459
Epoch 2/100, Loss: 2.3547680377960205, Reconstruction Loss: 0.08487919718027115, Domain Loss: 2.2698888778686523
Epoch 3/100, Loss: 2.346843957901001, Reconstruction Loss: 0.08409430086612701, Domain Loss: 2.262749671936035
Epoch 4/100, Loss: 2.255073308944702, Reconstruction Loss: 0.08387571573257446, Domain Loss: 2.1711976528167725
Epoch 5/100, Loss: 2.3037197589874268, Reconstruction Loss: 0.0839821994304657, Domain Loss: 2.2197375297546387
Epoch 6/100, Loss: 2.278475284576416, Reconstruction Loss: 0.08368973433971405, Domain Loss: 2.1947855949401855
Epoch 7/100, Loss: 2.3167102336883545, Reconstruction Loss: 0.08371131867170334, Domain Loss: 2.232998847961426
Epoch 8/100, Loss: 2.2950069904327393, Reconstruction Loss: 0.08376150578260422, Domain Loss: 2.211245536804199
Epoch 9/100, Loss: 2.2754087448120117, Reconstruction Loss: 0.0836806669831276, Domain Loss: 2.1917281150

KeyboardInterrupt: 

In [60]:
y_true.shape

(64, 9)