In [83]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import scanpy as sc
import anndata as ad
import scipy.sparse

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
adata = sc.read_h5ad("blinded-epithelial-subtype-anndata.h5ad")


In [3]:
adata

AnnData object with n_obs × n_vars = 24469 × 4148
    obs: 'sample_batch', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'dataset', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'doublet_scores', 'predicted_doublets', 'n_counts', 'n_genes', 'S_score', 'G2M_score', 'sample description', 'patient', 'blinded_subtype'
    var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
    uns: 'Control_epi_anno_colors_lung', 'Control_epi_anno_lung_colors', 'Control_epi_anno_nasal_colors', 'Control_major_anno_colors', 'dendrogram_leiden_0.1', 'eCRSwNP_epi_anno_colors_lung', 'eCRSwNP_epi_anno_lung_colors', 'eCRSwNP_epi_anno_nasal_colors', 'eCRSwNP_major_anno_colors', 'epi_anno_colors', 'epi_anno_lung_colors', 'epi_anno_nasal_colors', 'hvg', 'leiden', 'leiden_0.1_colors', 'leiden_0.2_colors', 'leiden_0.3

In [68]:
adata.obs['sample description'].unique()

['Polyp_NE', 'Polyp_E', 'Control_NE', 'Control_E']
Categories (4, object): ['Control_E', 'Control_NE', 'Polyp_E', 'Polyp_NE']

In [34]:
def load_and_inspect_data(adata):
    """
    Load the AnnData object and perform initial quality checks
    
    Parameters:
    -----------
    adata : AnnData
        Input single-cell RNA-seq dataset
    
    Returns:
    --------
    AnnData object with initial preprocessing information
    """
    # Basic dataset information
    print(f"Dataset shape: {adata.shape}")
    
    # Safely print observed variables
    print("Observed variables:")
    print(list(adata.obs.columns))
    
    # Safely print unstructured annotations
    print("\nUnstructured annotations:")
    print(list(adata.uns.keys()))
    
    # Layer information
    if hasattr(adata, 'layers'):
        print("\nAvailable layers:")
        for layer_name, layer_matrix in adata.layers.items():
            print(f"Layer: {layer_name}")
            print(f"  Shape: {layer_matrix.shape}")
            print(f"  Data type: {layer_matrix.dtype}")
            
            # Safe non-zero count
            try:
                non_zero = np.count_nonzero(layer_matrix)
                print(f"  Non-zero elements: {non_zero}")
            except Exception as e:
                print(f"  Could not count non-zero elements: {e}")
    
    return adata

In [35]:
data = load_and_inspect_data(adata)

Dataset shape: (24469, 4148)
Observed variables:
['sample_batch', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'dataset', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'doublet_scores', 'predicted_doublets', 'n_counts', 'n_genes', 'S_score', 'G2M_score', 'sample description', 'patient', 'blinded_subtype']

Unstructured annotations:
['Control_epi_anno_colors_lung', 'Control_epi_anno_lung_colors', 'Control_epi_anno_nasal_colors', 'Control_major_anno_colors', 'dendrogram_leiden_0.1', 'eCRSwNP_epi_anno_colors_lung', 'eCRSwNP_epi_anno_lung_colors', 'eCRSwNP_epi_anno_nasal_colors', 'eCRSwNP_major_anno_colors', 'epi_anno_colors', 'epi_anno_lung_colors', 'epi_anno_nasal_colors', 'hvg', 'leiden', 'leiden_0.1_colors', 'leiden_0.2_colors', 'leiden_0.3_colors', 'leiden_0.5_colors', 'leiden_0.8_colors', 'leiden_1.0_colors', 'leiden_2.0_colors', 'major_anno_colors', 'majority_vot

In [40]:
data.obs

Unnamed: 0,sample_batch,initial_size_unspliced,initial_size_spliced,initial_size,dataset,n_genes_by_counts,log1p_n_genes_by_counts,total_counts,log1p_total_counts,total_counts_mt,...,pct_counts_mt,doublet_scores,predicted_doublets,n_counts,n_genes,S_score,G2M_score,sample description,patient,blinded_subtype
AAACCTGCAGGA-1-0,ATGC-1,1551,1152,1152.0,neCRSwNP_2,202,8.544614,499.0,9.982622,0.0,...,0.0,0.105882,False,21646.0,5138,-0.138944,-0.137459,Polyp_NE,A,Epithelial_1
AAACCTGGTTAC-1-0,GTCA-1,84,77,77.0,neCRSwNP_2,23,6.453625,29.0,6.946976,0.0,...,0.0,0.081021,False,1039.0,634,-0.085944,-0.132128,Polyp_NE,A,Epithelial_1
AAACGGGAGTCA-1-0,ATAG-1,504,370,370.0,neCRSwNP_2,105,7.777374,299.0,8.765458,0.0,...,0.0,0.131222,False,6408.0,2385,-0.143451,-0.063642,Polyp_NE,A,Epithelial_4
AAACGGGCAGAC-1-0,AGGT-1,43,37,37.0,neCRSwNP_2,13,6.163315,29.0,7.293018,0.0,...,0.0,0.095588,False,1469.0,474,-0.034987,-0.018585,Polyp_NE,A,Epithelial_6
AAAGATGGTAGA-1-0,GGAA-1,1700,1119,1119.0,neCRSwNP_2,387,8.585039,695.0,9.808517,0.0,...,0.0,0.075936,False,18187.0,5350,-0.130134,-0.060174,Polyp_NE,A,Epithelial_3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGGTTCATGA-9-1,TCTG-1,325,349,349.0,Control_8,116,6.594413,139.0,6.926577,0.0,...,0.0,0.157303,False,1018.0,730,0.012019,-0.068563,Control_NE,J,Epithelial_1
TTTGGTTTCGCT-9-1,CATC-1,405,460,460.0,Control_8,131,6.752270,153.0,7.117206,0.0,...,0.0,0.203178,False,1232.0,855,-0.040392,-0.095224,Control_NE,J,Epithelial_1
TTTGTTGAGCTA-9-1,CTGT-1,1722,1455,1455.0,Control_8,294,7.736307,530.0,8.540323,0.0,...,0.0,0.078571,False,5116.0,2289,-0.012679,-0.060630,Control_NE,J,Epithelial_7
TTTGTTGCAGAT-9-1,TTCG-1,358,370,370.0,Control_8,144,6.612041,182.0,6.934397,0.0,...,0.0,0.141370,False,1026.0,743,-0.088261,-0.088655,Control_NE,J,Epithelial_1


In [36]:
def quality_control_with_layers(adata, 
                                layer='matrix', 
                                min_genes=200, 
                                max_genes=2500, 
                                max_mt_percent=10):
    """
    Perform quality control using matrix layer with robust filtering
    
    Parameters:
    -----------
    adata : AnnData
        Input single-cell RNA-seq dataset
    layer : str, optional (default='matrix')
        Layer to use for QC filtering
    
    Returns:
    --------
    Filtered AnnData object
    """
    # Ensure we're working with the correct layer
    if layer not in adata.layers:
        raise ValueError(f"Layer '{layer}' not found in AnnData object")
    
    # Use matrix layer for filtering
    adata.X = adata.layers[layer].copy()
    
    # Calculate mitochondrial gene percentage
    adata.var['mt'] = adata.var_names.str.startswith('MT-')
    sc.pp.calculate_qc_metrics(adata, 
                                qc_vars=['mt'], 
                                percent_top=None, 
                                log1p=False, 
                                inplace=True)
    
    # More robust filtering approach
    # Determine gene counts per cell
    if scipy.sparse.issparse(adata.X):
        gene_counts_per_cell = np.array((adata.X > 0).sum(axis=1)).flatten()
    else:
        gene_counts_per_cell = (adata.X > 0).sum(axis=1)
    
    # Create boolean masks for filtering
    genes_mask = (gene_counts_per_cell >= min_genes) & (gene_counts_per_cell <= max_genes)
    mt_mask = adata.obs['pct_counts_mt'] < max_mt_percent
    
    # Combine masks
    combined_mask = genes_mask & mt_mask
    
    # Apply filtering
    adata_filtered = adata[combined_mask, :]
    
    # Filter genes with low expression
    sc.pp.filter_genes(adata_filtered, min_cells=3)
    
    print(f"Original dataset shape: {adata.shape}")
    print(f"Filtered dataset shape: {adata_filtered.shape}")

    return adata_filtered
    

In [37]:
dta = quality_control_with_layers(adata)

  adata.var["n_cells"] = number


Original dataset shape: (24469, 4148)
Filtered dataset shape: (15532, 4131)


In [41]:
def normalize_layers(adata, 
                     layer='matrix', 
                     n_top_genes=2000):
    """
    Normalize data using specific layer
    
    Parameters:
    -----------
    adata : AnnData
        Input single-cell RNA-seq dataset
    layer : str, optional (default='matrix')
        Layer to use for normalization
    n_top_genes : int, optional (default=2000)
        Number of highly variable genes to select
    
    Returns:
    --------
    Normalized AnnData object
    """
    # Set the layer as primary matrix
    adata.X = adata.layers[layer].copy()
    
    # Normalize total
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    
    # Identify highly variable genes
    sc.pp.highly_variable_genes(adata, 
                                n_top_genes=n_top_genes, 
                                min_mean=0.0125, 
                                max_mean=3, 
                                min_disp=0.5)
    
    # Subset to highly variable genes
    adata = adata[:, adata.var['highly_variable']].copy()
    
    return adata

In [42]:
ndata = normalize_layers(dta)

In [48]:
def preprocess_layers(adata, layer='matrix'):
    """
    Complete preprocessing pipeline using layers
    
    Parameters:
    -----------
    adata : AnnData
        Raw single-cell RNA-seq dataset
    layer : str, optional (default='matrix')
        Layer to use for preprocessing
    
    Returns:
    --------
    Preprocessed AnnData object
    """
    # Inspect layers
    adata = load_and_inspect_data(adata)
    
    # Quality control using matrix layer
    adata = quality_control_with_layers(adata, layer=layer)
    
    # Normalize using matrix layer
    adata = normalize_layers(adata, layer=layer)
    
    # Dimensionality reduction
    sc.pp.pca(adata, n_comps=30)
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    
    return adata

In [49]:
pdata = preprocess_layers(adata)

Dataset shape: (24469, 4148)
Observed variables:
['sample_batch', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'dataset', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'doublet_scores', 'predicted_doublets', 'n_counts', 'n_genes', 'S_score', 'G2M_score', 'sample description', 'patient', 'blinded_subtype']

Unstructured annotations:
['Control_epi_anno_colors_lung', 'Control_epi_anno_lung_colors', 'Control_epi_anno_nasal_colors', 'Control_major_anno_colors', 'dendrogram_leiden_0.1', 'eCRSwNP_epi_anno_colors_lung', 'eCRSwNP_epi_anno_lung_colors', 'eCRSwNP_epi_anno_nasal_colors', 'eCRSwNP_major_anno_colors', 'epi_anno_colors', 'epi_anno_lung_colors', 'epi_anno_nasal_colors', 'hvg', 'leiden', 'leiden_0.1_colors', 'leiden_0.2_colors', 'leiden_0.3_colors', 'leiden_0.5_colors', 'leiden_0.8_colors', 'leiden_1.0_colors', 'leiden_2.0_colors', 'major_anno_colors', 'majority_vot

  adata.var["n_cells"] = number


Original dataset shape: (24469, 4148)
Filtered dataset shape: (15532, 4131)


  from .autonotebook import tqdm as notebook_tqdm


In [64]:
pdata.obs.columns

Index(['sample_batch', 'initial_size_unspliced', 'initial_size_spliced',
       'initial_size', 'dataset', 'n_genes_by_counts',
       'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts',
       'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt',
       'doublet_scores', 'predicted_doublets', 'n_counts', 'n_genes',
       'S_score', 'G2M_score', 'sample description', 'patient',
       'blinded_subtype'],
      dtype='object')

In [65]:
pdata.obs['blinded_subtype']

AAACCTGCAGGA-1-0    Epithelial_1
AAAGATGGTAGA-1-0    Epithelial_3
AAAGTAGTCTAA-1-0    Epithelial_1
AAATGCCAGACG-1-0    Epithelial_1
AACCATGCACCG-1-0    Epithelial_3
                        ...     
TTTGATCTCGCG-9-1    Epithelial_9
TTTGGAGAGACA-9-1    Epithelial_6
TTTGGAGCATAG-9-1    Epithelial_1
TTTGTTGAGCTA-9-1    Epithelial_7
TTTGTTGGTTCC-9-1    Epithelial_7
Name: blinded_subtype, Length: 15532, dtype: category
Categories (9, object): ['Epithelial_1', 'Epithelial_2', 'Epithelial_3', 'Epithelial_4', ..., 'Epithelial_6', 'Epithelial_7', 'Epithelial_8', 'Epithelial_9']

In [58]:
print(pdata.X.toarray()[0])
print(pdata.var_names[0])

[0. 0. 0. ... 0. 0. 0.]
SAMD11


In [59]:
# Option 1: Convert X matrix (processed gene expression)
df_expression = pd.DataFrame(pdata.X.toarray(), 
                              columns=pdata.var_names, 
                              index=pdata.obs_names)

In [60]:
df_expression

Unnamed: 0,SAMD11,ISG15,MXRA8,AL645728.1,CFAP74,AL034417.4,PIK3CD,ANGPTL7,TNFRSF1B,KAZN-AS1,...,HS6ST2,FGF13,LINC00632,AL031073.2,BGN,PNCK,TKTL1,IKBKG,TMLHE-AS1,AC007325.4
AAACCTGCAGGA-1-0,0.000000,0.000000,0.0,4.396427,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,3.046429,0.0,0.0,0.000000,0.0
AAAGATGGTAGA-1-0,0.000000,0.000000,0.0,3.393736,2.73362,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.0
AAAGTAGTCTAA-1-0,0.000000,0.000000,0.0,2.774935,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,2.774935,0.0
AAATGCCAGACG-1-0,0.000000,0.000000,0.0,0.000000,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,2.685205,0.0
AACCATGCACCG-1-0,2.407036,0.000000,0.0,3.054096,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,2.407036,0.0,0.0,0.000000,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGATCTCGCG-9-1,0.000000,0.000000,0.0,0.000000,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,2.926803,0.000000,0.0,0.0,0.000000,0.0
TTTGGAGAGACA-9-1,0.000000,0.000000,0.0,0.000000,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.0
TTTGGAGCATAG-9-1,0.000000,2.958570,0.0,0.000000,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.0
TTTGTTGAGCTA-9-1,0.000000,2.989107,0.0,0.000000,0.00000,0.0,0.0,0.0,0.0,0.0,...,0.0,0.000000,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.0


In [73]:
# Explicit GPU Detection and Setup
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9} GB")
    else:
        device = torch.device("cpu")
        print("No GPU available. Using CPU.")
    return device

In [74]:
device = get_device()
print(device)

Using GPU: NVIDIA TITAN Xp
GPU Memory: 12.788563968 GB
cuda


In [105]:
class AdvancedSingleCellVAE(nn.Module):
    def __init__(self, 
                 input_dim, 
                 latent_dim=50, 
                 hidden_dims=[512, 256], 
                 dropout_rate=0.3):
        """
        Advanced Variational Autoencoder for Single-Cell Data
        
        Args:
        - input_dim: Dimensionality of input features
        - latent_dim: Dimensionality of latent space
        - hidden_dims: List of hidden layer dimensions
        - dropout_rate: Dropout probability
        """
        super().__init__()
        
        # Input normalization
        self.input_norm = nn.BatchNorm1d(input_dim)
        
        # Encoder
        encoder_layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            encoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = hidden_dim
        
        self.encoder = nn.Sequential(*encoder_layers)
        
        # Latent space layers
        self.fc_mu = nn.Linear(prev_dim, latent_dim)
        self.fc_var = nn.Linear(prev_dim, latent_dim)
        
        # Decoder
        decoder_layers = []
        prev_dim = latent_dim
        for hidden_dim in reversed(hidden_dims):
            decoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = hidden_dim
        
        decoder_layers.append(nn.Linear(prev_dim, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)
    
    def reparameterize(self, mu, log_var):
        """Reparameterization trick for sampling"""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        """Forward pass through VAE"""
        x = self.input_norm(x)
        x_encoded = self.encoder(x)
        
        mu = self.fc_mu(x_encoded)
        log_var = self.fc_var(x_encoded)
        
        z = self.reparameterize(mu, log_var)
        x_reconstructed = self.decoder(z)
        
        return x_reconstructed, mu, log_var

In [106]:
def vae_loss_function(recon_x, x, mu, log_var, beta=1.0):
    """
    Variational Autoencoder Loss Function
    
    Args:
    - recon_x: Reconstructed input
    - x: Original input
    - mu: Latent space mean
    - log_var: Latent space log variance
    - beta: KL divergence weight
    
    Returns:
    Total loss
    """
    # Reconstruction loss
    recon_loss = nn.MSELoss(reduction='sum')(recon_x, x)
    
    # KL Divergence
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    
    return recon_loss + beta * kl_loss

In [107]:
def train_vae(
    data, 
    input_dim, 
    latent_dim=50, 
    epochs=300, 
    batch_size=256, 
    learning_rate=0.001, 
    beta=1.0
):
    """
    Train Variational Autoencoder
    
    Args:
    - data: Preprocessed training data
    - input_dim: Input feature dimensionality
    - latent_dim: Latent space dimensionality
    - epochs: Number of training epochs
    - batch_size: Training batch size
    - learning_rate: Optimizer learning rate
    - beta: KL divergence regularization weight
    
    Returns:
    Trained model and training history
    """
    # Initialize model
    model = AdvancedSingleCellVAE(
        input_dim=input_dim, 
        latent_dim=latent_dim
    )
    
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=0.5, 
        patience=10
    )
    
    # Training tracking
    train_losses = []
    val_losses = []
    
    # Data preparation
    X_train = data['train']
    X_test = data['test']
    
    # Training loop
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        
        # Batch training
        for i in range(0, len(X_train), batch_size):
            batch = X_train[i:i+batch_size]
            
            optimizer.zero_grad()
            
            # Forward pass
            recon_batch, mu, log_var = model(batch)
            
            # Compute loss
            loss = vae_loss_function(recon_batch, batch, mu, log_var, beta)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
        
        # Validation
        model.eval()
        with torch.no_grad():
            recon_test, mu_test, log_var_test = model(X_test)
            val_loss = vae_loss_function(
                recon_test, X_test, mu_test, log_var_test, beta
            )
        
        # Tracking
        train_avg_loss = total_train_loss / len(X_train)
        train_losses.append(train_avg_loss)
        val_losses.append(val_loss.item())
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        
        # Periodic reporting
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_avg_loss:.4f}, '
                  f'Val Loss = {val_loss.item():.4f}')
    
    return model, train_losses, val_losses

In [108]:
def analyze_vae_results(model, data, train_losses, val_losses):
    """
    Visualize VAE training results and latent space
    
    Args:
    - model: Trained VAE model
    - data: Preprocessed data dictionary
    - train_losses: Training losses
    - val_losses: Validation losses
    """
    # Loss visualization
    plt.figure(figsize=(15, 5))
    
    plt.subplot(131)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss Progression')
    plt.legend()
    
    plt.subplot(132)
    plt.hist(train_losses, bins=30, alpha=0.5, label='Train')
    plt.hist(val_losses, bins=30, alpha=0.5, label='Validation')
    plt.title('Loss Distribution')
    plt.legend()
    
    plt.subplot(133)
    plt.plot(np.convolve(train_losses, np.ones(10)/10, mode='valid'), label='Train MA')
    plt.plot(np.convolve(val_losses, np.ones(10)/10, mode='valid'), label='Val MA')
    plt.title('Moving Averages')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Latent space visualization
    model.eval()
    with torch.no_grad():
        _, mu, _ = model(data['test'])
    
    # Dimensionality reduction
    from sklearn.manifold import TSNE
    tsne = TSNE(n_components=2)
    mu_tsne = tsne.fit_transform(mu.numpy())
    
    plt.figure(figsize=(10, 8))
    plt.scatter(mu_tsne[:, 0], mu_tsne[:, 1])
    plt.title('Latent Space Visualization')
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')
    plt.show()

In [109]:
# Train VAE
model, train_losses, val_losses = train_vae(
    data=pdata,
    input_dim=adata.X.shape[1],
    latent_dim=50,
    epochs=300
)

KeyError: 'train'