In [1]:
import sys
sys.path.append("../src")


In [2]:
import torch
import scanpy as sc
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam

# local
from models import SparseAutoencoder
from dataset import EmbeddingDataset

In [3]:
adata = sc.read_h5ad("../data/external/adata_sample.h5ad")
embeddings_data = adata.obsm["geneformer"]

In [4]:
dset = EmbeddingDataset(embeddings_data)

In [5]:
model = SparseAutoencoder(512, 1024, expanded_ratio=4)

In [6]:
train_dataset, test_dataset = random_split(dset, [0.8, 0.2])

In [7]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [9]:
n_epoches = 1

optimizer = Adam(model.parameters(), lr=1e-3)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)
model.train()

for i_epoch in range(n_epoches):
    for i_step, d in enumerate(train_loader):
        d = d.to(device)
        optimizer.zero_grad()
        recon, embed = model(d)
        recon_loss, sparsity_loss, total_loss = model.get_total_loss(d, recon, embed)
        total_loss.backward()
        optimizer.step()

        if i_step % 100:
            print(f"epoches {i_epoch}, steps {i_step}, loss (recon_loss): {recon_loss.item()}, loss (sparsity): {sparsity_loss.item()}, loss (total): {total_loss.item()}")


epoches 0, steps 1, loss (recon_loss): 0.2937175929546356, loss (sparsity): 0.10660004615783691, loss (total): 0.2938241958618164
epoches 0, steps 2, loss (recon_loss): 0.08938649296760559, loss (sparsity): 0.07742805778980255, loss (total): 0.08946391940116882
epoches 0, steps 3, loss (recon_loss): 0.10133455693721771, loss (sparsity): 0.07861055433750153, loss (total): 0.10141316801309586
epoches 0, steps 4, loss (recon_loss): 0.09201420098543167, loss (sparsity): 0.08482982963323593, loss (total): 0.09209903329610825
epoches 0, steps 5, loss (recon_loss): 0.10279570519924164, loss (sparsity): 0.07272199541330338, loss (total): 0.10286843031644821
epoches 0, steps 6, loss (recon_loss): 0.07029101997613907, loss (sparsity): 0.0804959088563919, loss (total): 0.07037151604890823
epoches 0, steps 7, loss (recon_loss): 0.07630794495344162, loss (sparsity): 0.08493424952030182, loss (total): 0.07639288157224655
epoches 0, steps 8, loss (recon_loss): 0.046384297311306, loss (sparsity): 0.08