In [None]:
# ==============================================================================
# Script:           train_betaVAE.py
# Purpose:          Entry-point to train a betaVAE using a pancancer dataset
# Author:           Sophia Li
# Affiliation:      CCG Lab, Princess Margaret Cancer Centre, UHN, UofT
# Date:             12/31/2025
#
# Configurations:   betaVAE.yaml
#
# Notes:            Begins an Optuna hyperparameter sweep
# ==============================================================================

In [None]:
import os
from pathlib import Path

import pytorch_lightning as pl
import torch

from MethylCDM.data.methylation_datamodule import MethylDataModule
from MethylCDM.models.betaVAE import BetaVAE

%load_ext autoreload
%autoreload 2

In [None]:
# Directory holding the train/val/test .h5ad files. Export METHYLCDM_DATA_DIR to
# point the notebook at another machine's copy instead of editing the paths
# below; the fallback is the original external-drive location, so existing
# setups behave exactly as before.
DATA_DIR = Path(
    os.environ.get(
        "METHYLCDM_DATA_DIR",
        "/Volumes/FBI_Drive/MethylCDM-project/data/training/methylation",
    )
)

# Paths are converted back to plain strings so MethylDataModule receives the
# same argument type it did when the paths were written out literally.
dm = MethylDataModule(
    train_adata_path = str(DATA_DIR / "tcga_train_gene_matrix.h5ad"),
    val_adata_path = str(DATA_DIR / "tcga_val_gene_matrix.h5ad"),
    test_adata_path = str(DATA_DIR / "tcga_test_gene_matrix.h5ad"),
    batch_size = 128,
    num_workers = 4
)

In [None]:
# Run the datamodule's setup step, then draw one training batch so the shape of
# its "methylation_data" entry can be inspected.
dm.setup()
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))
print(batch["methylation_data"].shape)

In [None]:
# Display the size of axis 1 of the sampled methylation batch. The same
# expression is what the BetaVAE constructor below receives as input_dim.
# NOTE(review): presumably this is the number of features per sample — confirm
# against MethylDataModule's batch layout.
batch["methylation_data"].shape[1]

In [None]:
# Seed torch's global RNG before the model is built. Without a seed, anything
# BetaVAE draws from torch during construction (presumably its weight
# initialisation — confirm in MethylCDM.models.betaVAE) differs on every
# Restart & Run All, even though the Trainer requests deterministic=True.
SEED = 42
torch.manual_seed(SEED)

# Instantiate the beta-VAE. input_dim is read from axis 1 of the sampled batch
# so the model always matches whatever the datamodule produced; the remaining
# values are the hyperparameters used for this trial run.
model = BetaVAE(
    input_dim = batch["methylation_data"].shape[1],
    latent_dim = 200,
    encoder_dims = [6000, 4000],
    decoder_dims = [4000, 6000],
    beta = 0.005,
    lr = 0.003
)


In [None]:
# Smoke-test the freshly built model on the sampled batch. Autograd is switched
# off because the outputs are only inspected, never back-propagated.
methylation_batch = batch["methylation_data"]
with torch.no_grad():
    x_hat, mu, logvar = model(methylation_batch)

In [None]:
# Print the shapes of the three forward-pass outputs on a single line.
print(*(output.shape for output in (x_hat, mu, logvar)))


In [None]:
# Trainer for a short five-epoch trial run. accelerator="auto" lets Lightning
# use a GPU when one is present and fall back to CPU otherwise, so the cell no
# longer has to be hand-edited on machines without a GPU. `pl` comes from the
# import cell at the top of the notebook.
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
    log_every_n_steps=10,
    deterministic=True,
)

