# Implementation

In [2]:
import numpy as np
import argparse
from data_loader import load_data, preprocess
from model import Autoencoder, ClusteringLayer 
from trainer import train, get_embeddings
from cluster import evaluate_cluster
import torch

def _pretrain_autoencoder(autoencoder, data_tensor, pretrain_epochs, learning_rate, batch_size):
    """Pretrain ``autoencoder`` in place on a pure reconstruction objective.

    Minimises the MSE between each mini-batch and the second value returned
    by ``autoencoder(batch)`` (the reconstruction), using Adam.

    Args:
        autoencoder: Module whose forward pass returns ``(_, decoded)``.
        data_tensor: Float tensor of samples, already on the model's device.
        pretrain_epochs: Number of passes over ``data_tensor``.
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: Mini-batch size.
    """
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()
    # Built once: a shuffling DataLoader draws a fresh permutation every time
    # it is iterated, so there is no need to recreate it (or the tensor it
    # wraps) on each epoch.
    loader = torch.utils.data.DataLoader(data_tensor, batch_size=batch_size, shuffle=True)

    autoencoder.train()  # set to train mode
    for epoch in range(pretrain_epochs):
        loss = None
        for batch in loader:
            optimizer.zero_grad()
            _, decoded = autoencoder(batch)
            loss = criterion(decoded, batch)
            loss.backward()
            optimizer.step()
        # The reported value is the loss of the last mini-batch of the epoch;
        # ``loss`` stays None when the loader yielded no batches at all.
        if loss is not None and (epoch + 1) % 10 == 0:
            print(f"Pretrain Epoch: {epoch+1}, Loss: {loss.item()}")


def main(
    filepath,
    file_format,
    n_clusters,
    encoding_dim,
    epochs,
    learning_rate,
    batch_size,
    normalize,
    scale,
    log_transform,
    n_top_genes,
    pretrain_epochs=200,
    ae_weights=None,
):
    """Run the full pipeline: load, preprocess, (pre)train, embed, cluster, evaluate.

    Args:
        filepath: Path of the data file, passed to ``load_data``.
        file_format: Passed to ``load_data`` as ``format``.
        n_clusters: Number of clusters for the clustering layer and ``train``.
        encoding_dim: Size of the autoencoder's encoding.
        epochs: Number of epochs for the joint training done by ``train``.
        learning_rate: Learning rate used for pretraining and passed to ``train``.
        batch_size: Mini-batch size used for pretraining and passed to ``train``.
        normalize, scale, log_transform, n_top_genes: Forwarded unchanged to
            ``preprocess``.
        pretrain_epochs: Number of autoencoder pretraining epochs; only used
            when ``ae_weights`` is None.
        ae_weights: Optional path to a saved autoencoder ``state_dict``. When
            given, pretraining is skipped and these weights are loaded.

    Returns:
        ``(cluster_labels, embeddings, scaler)`` on success, or ``None`` when
        ``load_data`` returns no data.
    """
    # load and preprocess data
    print("Loading and Preprocessing Data")
    data, labels = load_data(filepath, format=file_format)
    if data is None:
        print("Failed to load data. Exiting.")
        return
    processed_data, scaler = preprocess(data, normalize, scale, log_transform, n_top_genes)
    input_dim = processed_data.shape[1]

    # define and train model
    print("Defining and Training Model")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    autoencoder = Autoencoder(input_dim, encoding_dim).to(device)
    clustering_layer = ClusteringLayer(n_clusters, encoding_dim).to(device)

    if ae_weights is None:  # pretraining is skipped if ae_weights is provided
        print("Pretraining autoencoder...")
        data_tensor = torch.tensor(processed_data, dtype=torch.float32).to(device)
        _pretrain_autoencoder(autoencoder, data_tensor, pretrain_epochs, learning_rate, batch_size)
        #torch.save(autoencoder.state_dict(), "ae_weights.pt") #saving weights for testing purposes
    else:
        # map_location=device lets a checkpoint saved on any device be loaded
        # straight onto the device the model lives on.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        autoencoder.load_state_dict(torch.load(ae_weights, map_location=device))

    # train jointly
    trained_model, trained_clustering_layer = train(autoencoder, clustering_layer, processed_data, epochs, batch_size, learning_rate, n_clusters, device)

    # get embeddings using the trained autoencoder
    print("Obtaining Embeddings")
    embeddings = get_embeddings(trained_model, processed_data)

    # clustering using the trained clustering layer
    print("Clustering")
    with torch.no_grad():  # inference only: no autograd graph is needed
        embedding_tensor = torch.tensor(embeddings, dtype=torch.float32).to(device)
        cluster_labels = trained_clustering_layer(embedding_tensor).argmax(1).cpu().numpy()

    # Evaluate only against real labels. Scoring against randomly generated
    # placeholder labels would print ARI/NMI values that carry no information
    # about clustering quality.
    if labels is not None:
        evaluation_results = evaluate_cluster(labels, cluster_labels)
        print(f"Clustering Evaluation: ARI={evaluation_results['ARI']:.4f}, NMI={evaluation_results['NMI']:.4f}")
    else:
        print("No ground-truth labels available; skipping clustering evaluation.")
    return cluster_labels, embeddings, scaler

if __name__ == "__main__":
    # --- input data ---
    # Alternative dataset: 'Splatter_Sim_Data/splatter_simulate_data_1.h5'
    filepath = 'scDeepClustering_Sample_Data/mouse_bladder_cell_select_2100.h5'
    file_format = 'h5'

    # --- preprocessing switches (forwarded to preprocess via main) ---
    normalize = True
    scale = True
    log_transform = True
    n_top_genes = 2000

    # --- model size ---
    n_clusters = 3
    encoding_dim = 32

    # --- optimisation ---
    learning_rate = 0.001
    batch_size = 32
    epochs = 100
    pretrain_epochs = 100  # can adjust
    # Set to a checkpoint path such as "ae_weights.pt" to load saved weights
    # instead of pretraining on every run.
    ae_weights = None

    # Keyword arguments keep the long call readable and order-independent.
    main(
        filepath=filepath,
        file_format=file_format,
        n_clusters=n_clusters,
        encoding_dim=encoding_dim,
        epochs=epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        normalize=normalize,
        scale=scale,
        log_transform=log_transform,
        n_top_genes=n_top_genes,
        pretrain_epochs=pretrain_epochs,
        ae_weights=ae_weights,
    )

Loading and Preprocessing Data
Keys in file: ['X', 'Y']
=== DATA: ===
[[  9   0   0 ...   2 253  18]
 [  9   0   4 ...   5 268  13]
 [ 10   0   0 ...   0 286  45]
 ...
 [ 20   0   2 ...   1 200  21]
 [ 19   4   0 ...   2 211  19]
 [  0   0   0 ...   2 170  21]]


=== DATA AFTER: ===
[[ 1.37633088 -0.07424526 -0.53139794 ... -0.82777124  0.13916345
  -0.08761217]
 [ 0.15100039  0.32316539 -0.53139794 ... -1.31165484 -0.6785124
   0.86436923]
 [-1.05594923 -0.77410053 -0.53139794 ... -2.02494913 -0.82194957
  -0.09430186]
 ...
 [ 0.70618508 -0.11371273 -0.53139794 ...  0.49143894 -1.80315314
   0.9931173 ]
 [-0.91948786  0.07881428 -0.53139794 ...  1.539124   -0.64641285
  -1.82330624]
 [ 0.59652344 -0.40862739  0.94749624 ...  0.87469364 -0.12875263
   0.19072659]]


Defining and Training Model
Pretraining autoencoder...
Pretrain Epoch: 10, Loss: 0.8773766756057739
Pretrain Epoch: 20, Loss: 0.9208022952079773
Pretrain Epoch: 30, Loss: 0.8740976452827454
Pretrain Epoch: 40, Loss: 0.85439